#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/variable.h>

#include <ATen/ATen.h>
#include <ATen/Utils.h>
#include <c10/core/TensorOptions.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/ExpandUtils.h>
#include <ATen/core/Reduction.h>
#include <ATen/BatchedTensorImpl.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarOps.h>
#include <ATen/SparseTensorUtils.h>

#include <ciso646>
#include <algorithm>
#include <numeric>
#include <functional>
// Helper functions for autogenerated code
// These used to be inlined into the codegened Functions.cpp

namespace torch {
namespace autograd {
namespace generated {
namespace details {

using at::Tensor;
using at::Scalar;
using at::IntArrayRef;
using at::TensorList;

// True only when the optional carries a tensor and that tensor is defined.
bool isDefined(const c10::optional<Tensor>& t) {
  if (!t.has_value()) {
    return false;
  }
  return t->defined();
}

// True when the optional carries a defined tensor whose forward gradient at
// level 0 is also defined.
bool isFwGradDefined(const c10::optional<Tensor>& t) {
  if (!t.has_value() || !t->defined()) {
    return false;
  }
  return t->fw_grad(/*level */ 0).defined();
}

// Unwraps an optional tensor; an empty optional maps to an undefined Tensor.
Tensor toLegacyTensor(const c10::optional<Tensor>& t) {
  if (t.has_value()) {
    return *t;
  }
  return Tensor();
}

// Returns the level-0 forward gradient of the wrapped tensor, or an undefined
// Tensor when the optional is empty or holds an undefined tensor.
Tensor toLegacyFwGrad(const c10::optional<Tensor>& t) {
  if (!t.has_value() || !t->defined()) {
    return Tensor();
  }
  return t->fw_grad(/*level */ 0);
}

// Returns the level-0 forward primal of the wrapped tensor, or an undefined
// Tensor when the optional is empty or holds an undefined tensor.
Tensor toLegacyPrimal(const c10::optional<Tensor>& t) {
  if (!t.has_value() || !t->defined()) {
    return Tensor();
  }
  return t->_fw_primal(/*level */ 0);
}

// Stores a single tensor into `out` at the slot described by `range`.
// `range` must cover exactly one element and must lie inside `out`.
void copy_range(variable_list& out, IndexRange range, const Tensor & t) {
  AT_ASSERT(range.second <= out.size());
  AT_ASSERTM(range.second - range.first == 1, "inconsistent range for Tensor output");
  auto& slot = out[range.first];
  slot = t;
}

// Stores a list of tensors into `out`, starting at `range.first`.
// `range` must have the same length as `t` and must lie inside `out`.
void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
  AT_ASSERT(range.second <= out.size());
  AT_ASSERTM(range.second - range.first == t.size(), "inconsistent range for TensorList output");
  auto dst = out.begin() + range.first;
  for (auto src = t.begin(); src != t.end(); ++src, ++dst) {
    *dst = *src;
  }
}

// Gradient of copysign w.r.t. `self`: grad * (result / self), with the ratio
// forced to 0 wherever `self` is exactly 0.
Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & result) {
  auto sign_ratio = result / self;
  const auto self_is_zero = (self == 0);
  sign_ratio.masked_fill_(self_is_zero, 0);
  return grad * sign_ratio;
}

// Always throws std::runtime_error stating that the derivative for `name`
// is not implemented; never returns.
Tensor not_implemented(const char* name) {
  std::string message("the derivative for '");
  message += name;
  message += "' is not implemented";
  throw std::runtime_error(message);
}

// Returns `t * s`, skipping the multiplication when `s` is a floating-point or
// integral (bool included) scalar equal to 1.
Tensor maybe_multiply(const Tensor & t, const Scalar & s) {
  if (s.isFloatingPoint()) {
    if (s.toDouble() == 1) {
      return t;
    }
  } else if (s.isIntegral(true) && s.toLong() == 1) {
    return t;
  }
  return t * s;
}

// Product of the extents of `sizes` along the (possibly negative) dimensions
// listed in `dim`. Returns 1 when `sizes` is empty.
int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
  if (sizes.size() == 0) {
    return 1;
  }
  int64_t product = 1;
  for (size_t k = 0; k < dim.size(); ++k) {
    const auto wrapped = at::maybe_wrap_dim(dim[k], sizes.size());
    product *= sizes[wrapped];
  }
  return product;
}

// Converts a Scalar to a tensor and flags that tensor as a wrapped number.
static Tensor wrapped_scalar_tensor(Scalar scalar) {
  auto wrapped = scalar_to_tensor(scalar);
  auto* impl = wrapped.unsafeGetTensorImpl();
  impl->set_wrapped_number(true);
  return wrapped;
}

// When the input dtype is real but the computed gradient is complex (an
// R -> C function), keeps only the real part; otherwise returns the gradient
// untouched.
Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result) {
  const bool input_is_real = !at::isComplexType(self_st);
  if (input_is_real && gradient_result.is_complex()) {
    return at::real(gradient_result);
  }
  return gradient_result;
}

// When `self` is real but the computed gradient is complex (an R -> C
// function), keeps only the real part; otherwise returns the gradient
// untouched.
Tensor handle_r_to_c(Tensor self, Tensor gradient_result) {
  const bool input_is_real = !self.is_complex();
  if (input_is_real && gradient_result.is_complex()) {
    return at::real(gradient_result);
  }
  return gradient_result;
}

// Reshapes `output` so that every dimension listed in `dims` reappears with
// size 1. With keepdim=true the tensor is returned unchanged.
Tensor restore_reduced_dims(const Tensor &output, IntArrayRef dims, bool keepdim) {
  if (keepdim) {
    return output;
  }
  const int64_t full_rank = output.dim() + dims.size();
  // 0 marks "not yet assigned"; reduced positions are pre-filled with 1.
  std::vector<int64_t> shape(full_rank, 0);
  for (int64_t d : dims) {
    const int64_t pos = (d < 0) ? full_rank + d : d;
    shape[pos] = 1;
  }
  // Pour the surviving extents, in order, into the slots not claimed above.
  int64_t cursor = 0;
  for (int64_t extent : output.sizes()) {
    while (shape[cursor] > 0) {
      ++cursor;
    }
    shape[cursor] = extent;
    ++cursor;
  }
  return output.reshape(shape);
}

// Divides `grad` by the sum of `mask` over `dims` (kept as size-1 dims) and
// multiplies the result by `mask`.
Tensor scale_grad_by_count(const Tensor &grad, const Tensor &mask, IntArrayRef dims) {
  const auto count = mask.sum(dims, true);
  const auto per_element = grad / count;
  return per_element * mask;
}

std::tuple<Tensor, Tensor> _euclidean_dist_backward(const Tensor & grad, const Tensor & x1, const Tensor & x2, const Tensor & res) {
  if (!grad.defined()) {
    return std::tuple<Tensor, Tensor>(Tensor(), Tensor());
  }
  // handle case at 0 where we return a subgradient containing 0
  Tensor ratio = grad / res;
  ratio.masked_fill_(res == 0, 0);
  return std::tuple<Tensor, Tensor>{
            x1 * ratio.sum(-1, true) - ratio.matmul(x2),
            x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.transpose(-2, -1).matmul(x1)};
}

// Backward of the p-norm over the whole tensor. `p_` defaults to 2 when not
// supplied; `norm` is the forward result.
Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional<Scalar> & p_, const Tensor & norm) {
  const double p = p_.value_or(2.0).toDouble();

  // These two orders need no division by the norm, so return right away.
  if (p == 0.0) {
    return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  if (p == 1.0) {
    return self.sgn() * grad;
  }

  Tensor direction;  // factor that depends on `self`
  Tensor magnitude;  // factor that depends on `grad` and `norm`
  if (p == 2.0) {
    direction = self;
    magnitude = grad / norm;
  } else if (std::isinf(p)) {
    // Only entries whose absolute value equals the norm receive gradient.
    direction = self.sgn() * (self.abs() == norm).type_as(self);
    magnitude = grad.clone(at::MemoryFormat::Preserve);
  } else if (p < 2.0) {
    direction = self.sgn() * self.abs().pow(p - 1);
    magnitude = grad / norm.pow(p - 1);
  } else {
    direction = self * self.abs().pow(p - 2);
    magnitude = grad / norm.pow(p - 1);
  }
  // Where the norm is 0, use the subgradient 0.
  magnitude.masked_fill_(norm == 0, 0);
  return direction * magnitude;
}

// Backward of the p-norm reduced over `dim`. When the reduction dropped the
// reduced dimensions (keepdim=false on a non-scalar input), `grad` and `norm`
// get those dimensions back as size 1 before delegating to the full-tensor
// overload.
Tensor norm_backward(Tensor grad, const Tensor & self, const optional<Scalar> & p_, Tensor norm, IntArrayRef dim, bool keepdim) {
  const bool must_unsqueeze = !keepdim && self.dim() != 0;
  if (must_unsqueeze) {
    if (dim.size() == 1) {
      grad = grad.unsqueeze(dim[0]);
      norm = norm.unsqueeze(dim[0]);
    } else {
      const size_t ndim = self.sizes().size();
      auto reduced = at::dim_list_to_bitset(dim, ndim);
      for (size_t axis = 0; axis < ndim; ++axis) {
        if (!reduced[axis]) {
          continue;
        }
        grad = grad.unsqueeze(axis);
        norm = norm.unsqueeze(axis);
      }
    }
  }
  return norm_backward(grad, self, p_, norm);
}

// Gradient of self^exponent w.r.t. `self` for a scalar exponent:
// grad * conj(exponent * self^(exponent - 1)), and zeros when the exponent
// is 0.
Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent_) {
  auto exponent = (exponent_.isComplex()) ? exponent_.toComplexDouble() : exponent_.toDouble();
  if (exponent == 0.0) {
    return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  auto local_derivative = exponent * self.pow(exponent - 1);
  auto out = grad * local_derivative.conj();
  return handle_r_to_c(self, out);
}

// Gradient of self^exponent w.r.t. `self` for a tensor exponent. Entries whose
// exponent is 0 get gradient 0.
Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) {
  auto zero = at::zeros({}, grad.options());
  auto nonzero_case = grad * (exponent * self.pow(exponent - 1)).conj();
  auto out = at::where(exponent == 0.0, zero, nonzero_case);
  return handle_r_to_c(self, out);
}

// Gradient of a^b with respect to the exponent b (tensor base).
//
// Conventions at a = 0:
//  * For b < 0 we define d(a^b)/db = -inf, because d(a^b)/db -> -inf for a
//    fixed b as a -> +0. (TensorFlow currently yields nan here.)
//  * For b = 0 we define d(a^b)/db = 0 by continuity, since d(a^b)/db = 0 for
//    a > 0 as b -> +0. (TensorFlow currently agrees.)
Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) {
  // Entries with a zero base and a real, non-negative exponent get gradient 0.
  Tensor exponent_ok;
  if (exponent.is_complex()) {
    exponent_ok = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
  } else {
    exponent_ok = (exponent >= 0);
  }
  Tensor cond = at::logical_and(self == 0, exponent_ok);

  auto masked = at::where(cond,
                          at::zeros({}, grad.options()),
                          (result * self.log()).conj());
  auto out = grad * masked;
  return handle_r_to_c(exponent, out);
}

// Gradient of base^exponent with respect to the exponent for a scalar base:
// grad * conj(result * log(base)). When the base is 0, entries whose exponent
// is real and non-negative get gradient 0 instead.
Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) {
  auto base_ = base.isComplex() ? base.toComplexDouble() : base.toDouble();
  auto unmasked = (result * std::log(base_)).conj();

  if (!(base_ == 0.0)) {
    auto out = grad * unmasked;
    return handle_r_to_c(exponent, out);
  }

  Tensor nonneg_real_exp;
  if (exponent.is_complex()) {
    nonneg_real_exp = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
  } else {
    nonneg_real_exp = (exponent >= 0);
  }
  auto out = grad * at::where(nonneg_real_exp,
                              at::zeros({}, grad.options()),
                              unmasked);
  return handle_r_to_c(exponent, out);
}

// Backward of angle(). Real inputs get a zero gradient; complex inputs get
// grad * self / |self|^2 * i, with 0 wherever self == 0.
Tensor angle_backward(Tensor grad, const Tensor& self) {
  if (!self.is_complex()) {
    return at::zeros_like(self, at::MemoryFormat::Preserve);
  }
  auto nonzero_case =
      grad * self / self.abs().pow(2) * Scalar(c10::complex<double>{0.0, 1.0});
  return at::where(self == 0.0, at::zeros({}, self.options()), nonzero_case);
}

// Backward of mvlgamma: grad times the sum of digamma(self + offset) over the
// offsets arange(-p/2 + 0.5, 0.5, 0.5).
Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) {
  Tensor offsets = at::arange(-p / 2. + 0.5, 0.5, 0.5, self.options());
  // Broadcast: one trailing entry per offset for every element of self.
  Tensor shifted = offsets.add(self.unsqueeze(-1));
  return grad * shifted.digamma_().sum(-1);
}

// Backward of sgn(). Real inputs get a zero gradient. For complex inputs
// (C -> C) see https://arxiv.org/pdf/1701.00392.pdf Section 4.20; entries
// with |self| == 0 get gradient 0.
Tensor sgn_backward(Tensor result, Tensor grad, Tensor self) {
  if (!self.is_complex()) {
    return at::zeros_like(self, at::MemoryFormat::Preserve);
  }
  auto magnitude = at::abs(self);
  auto nonzero_case = (grad/magnitude - (at::real(grad/self) * result));
  return at::where(magnitude == 0.0, at::zeros({}, grad.options()), nonzero_case);
}

// Gradient of self * other w.r.t. `self`: grad * conj(other), reduced to its
// real part when `self_st` is a real dtype.
Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) {
  auto product = grad * other.conj();
  return handle_r_to_c(self_st, product);
}

// Gradient of self / other w.r.t. `self`: grad / conj(other), reduced to its
// real part when `self_st` is a real dtype.
Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st) {
  auto quotient = grad / other.conj();
  return handle_r_to_c(self_st, quotient);
}

// Gradient of self / other w.r.t. `other`: -grad * conj(self / other / other),
// reduced to its real part when `other` is real.
Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other) {
  auto self_over_other_sq = ((self / other) / other).conj();
  auto out = -grad * self_over_other_sq;
  return handle_r_to_c(other, out);
}

// Applies the inverse of the forward permutation `fwd_dims` to `grad`.
Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
  const auto rank = fwd_dims.size();
  std::vector<int64_t> inverse(rank);
  size_t position = 0;
  for (const auto fwd : fwd_dims) {
    // Forward moved dim `fwd` to `position`; the inverse sends it back.
    inverse[at::maybe_wrap_dim(fwd, rank)] = position;
    ++position;
  }
  return grad.permute(inverse);
}

// Backward of rad2deg: scales the incoming gradient by 180 / pi.
Tensor rad2deg_backward(const Tensor& grad) {
  constexpr double kDegreesPerRadian = 57.295779513082320876798154814105170332405472466564;
  auto factor = wrapped_scalar_tensor(Scalar(kDegreesPerRadian));
  return at::mul(grad, factor);
}

// Backward of deg2rad: scales the incoming gradient by pi / 180.
Tensor deg2rad_backward(const Tensor& grad) {
  constexpr double kRadiansPerDegree = 0.017453292519943295769236907684886127134428718885417;
  auto factor = wrapped_scalar_tensor(Scalar(kRadiansPerDegree));
  return at::mul(grad, factor);
}

// Inserts a size-1 dimension at every position listed in `dim`, where `dim`
// refers to a tensor with `n_dims` dimensions. Positions are processed in
// ascending order.
Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) {
  auto selected = at::dim_list_to_bitset(dim, n_dims);
  Tensor expanded = t;
  for (size_t axis = 0; axis < n_dims; ++axis) {
    if (!selected[axis]) {
      continue;
    }
    expanded = expanded.unsqueeze(axis);
  }
  return expanded;
}

// Backward of sum over `dims`: re-inserts the reduced dimensions (unless
// keepdim was set or the input had no dimensions) and expands to `sizes`.
Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) {
  if (keepdim || sizes.size() == 0) {
    return grad.expand(sizes);
  }
  if (dims.size() == 1) {
    return grad.unsqueeze(dims[0]).expand(sizes);
  }
  Tensor unsqueezed = unsqueeze_multiple(grad, dims, sizes.size());
  return unsqueezed.expand(sizes);
}

// Backward of nansum over `dims`: same expansion as a plain sum, then
// multiplied by a mask that is false wherever `self` is NaN.
Tensor nansum_backward(const Tensor & grad, const Tensor & self, IntArrayRef dims, bool keepdim) {
  auto sizes = self.sizes();
  Tensor expanded;
  if (keepdim || sizes.size() == 0) {
    expanded = grad.expand(sizes);
  } else if (dims.size() == 1) {
    expanded = grad.unsqueeze(dims[0]).expand(sizes);
  } else {
    expanded = unsqueeze_multiple(grad, dims, sizes.size()).expand(sizes);
  }
  return expanded * self.isnan().logical_not();
}

// Returns the elements of `list` in reverse order as a new vector.
std::vector<int64_t> reverse_list(const IntArrayRef list) {
  return std::vector<int64_t>(list.rbegin(), list.rend());
}

// Reverses `t` along `dim` by index_select with the descending index
// [size - 1, ..., 0].
Tensor reverse_dim(const Tensor& t, int64_t dim) {
  const auto index_options = t.options().dtype(at::kLong);
  Tensor descending = at::arange(t.size(dim) - 1, -1, -1, index_options);
  return t.index_select(dim, descending);
}

// Division-free backward of prod along `dim`: each element's gradient is
// grad times the product of all elements before it along `dim` (exclusive
// forward cumprod) times the product of all elements after it (exclusive
// reverse cumprod).
Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t dim) {
  const auto length = inp.size(dim);
  if (length == 1) {
    return grad;
  }

  // A slice of ones seeds both exclusive cumulative products.
  auto ones_size = inp.sizes().vec();
  ones_size[dim] = 1;
  Tensor ones = at::ones(ones_size, grad.options());

  // Exclusive forward cumprod: [1, a, a*b, ...]
  Tensor before = at::cat({ones, inp.narrow(dim, 0, length - 1)}, dim).cumprod(dim);

  // Exclusive reverse cumprod: [..., c, 1], built on the reversed tail.
  Tensor tail_reversed = reverse_dim(inp.narrow(dim, 1, length - 1), dim);
  Tensor after = reverse_dim(at::cat({ones, tail_reversed}, dim).cumprod(dim), dim);

  return grad * (before * after);
}

// Backward of prod over the whole tensor.
//
// The division-free form used when the input contains a zero is
//   cumprod(exclusive, normal) * cumprod(exclusive, reverse), e.g.:
//   input:                        [    a,     b,     c]
//   cumprod(exclusive, normal):   [1    ,     a, a * b]
//   cumprod(exclusive, reverse):  [b * c,     c,     1]
//   product:                      [b * c, a * c, a * b]
// which stays well defined for inputs containing 0s.
Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& result) {
  if (input.dim() == 0) {
    return grad;
  }
  Tensor zero_idx = (input == 0).nonzero();
  // No zeros: the cheap formula grad * result / input is valid.
  if (zero_idx.numel() == 0) {
    return (grad * result) / input;
  }
  // More than one zero: every partial product contains a zero.
  if (zero_idx.size(0) > 1) {
    return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  // Exactly one zero: use the division-free path on the flattened input.
  Tensor flat_grad = prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0);
  return flat_grad.view_as(input);
}

// Backward of prod along a single `dim`. Uses grad * result / input when the
// input contains no zeros at all, and the division-free path otherwise.
Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t dim, bool keepdim) {
  if (input.dim() == 0) {
    return grad;
  }
  dim = at::maybe_wrap_dim(dim, input.sizes().size());
  // Give grad/result the reduced dimension back so they broadcast with input.
  const bool must_unsqueeze = !keepdim && input.dim() != 1;
  if (must_unsqueeze) {
    grad = grad.unsqueeze(dim);
    result = result.unsqueeze(dim);
  }

  Tensor per_slice_zeros = (input == 0).sum(dim, true);
  const int64_t total_zeros = per_slice_zeros.sum().item<int64_t>();
  if (total_zeros != 0) {
    return prod_safe_zeros_backward(grad, input, dim);
  }
  return (grad * result) / input;
}

// Gradient of solve w.r.t. the right-hand side: solves A^H x = grad.
// `self` is not used by the computation.
Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) {
  auto A_conj_transposed = A.conj().transpose(-2, -1);
  return at::linalg_solve(A_conj_transposed, grad);
}

// Gradient of solve w.r.t. the coefficient matrix A:
// -(grad_self @ solution^H), with vector right-hand sides treated as a
// single column.
Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) {
  Tensor grad_self = solve_backward_self(grad, self, A);

  // Plain matrix case: a single mm is enough.
  if (self.ndimension() == 2 && A.ndimension() == 2) {
    return -at::mm(grad_self, solution.conj().transpose(-2, -1));
  }

  // Was self unsqueezed from (..., M) to (..., M, 1)?
  auto batched_rhs_shape = IntArrayRef(A.sizes().data(), A.dim()-1);  // A.shape[:-1]
  const bool rhs_is_vector =
      self.dim() == 1 || (A.dim()-1 == self.dim() && self.sizes().equals(batched_rhs_shape));
  if (!rhs_is_vector) {
    return -at::matmul(grad_self, solution.conj().transpose(-2, -1));
  }
  return -at::matmul(grad_self.unsqueeze(-1), solution.unsqueeze(-1).conj().transpose(-2, -1));
}

// Backward of cumsum along `dim`, computed as
//   x + cumsum(-x) - last(cumsum(-x))
// i.e. a reversed cumulative sum of `x`.
Tensor cumsum_backward(const Tensor & x, int64_t dim) {
  // Scalars (dim() == 0) and empty tensors such as shape [0, 2] (numel() == 0)
  // are returned unchanged.
  if (x.dim() == 0 || x.numel() == 0) {
    return x;
  }
  auto acc = at::cumsum(-x, dim);
  const auto last_index = acc.size(dim) - 1;
  // Clone: `acc` is about to be modified in place.
  auto total = acc.narrow(dim, last_index, 1).clone(at::MemoryFormat::Preserve);
  acc -= total.expand(acc.sizes());
  acc += x;
  return acc;
}

// Backward of logsumexp over `dim`: grad * exp(self - result), after
// restoring the reduced dimensions on `grad` and `result` when keepdim was
// false and the input is not a scalar.
Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) {
  const bool must_unsqueeze = !keepdim && self.dim() != 0;
  if (must_unsqueeze) {
    const size_t ndim = self.sizes().size();
    grad = unsqueeze_multiple(grad, dim, ndim);
    result = unsqueeze_multiple(result, dim, ndim);
  }
  auto weights = (self - result).exp();
  return grad * weights;
}

// Backward of logcumsumexp along `dim`.
//
// The positive and negative parts of `grad` are handled separately in log
// space, each pushed through a reversed logcumsumexp, and recombined as
// output_pos - output_neg.
//
// Args:
//  grad     incoming gradient
//  self     input of the forward op
//  result   output of the forward op
//  dim      dimension the cumulative op ran over
Tensor logcumsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim) {
  // Scalars and empty tensors: the gradient is returned unchanged.
  if (grad.dim() == 0 || grad.numel() == 0) {
    return grad;
  }

  // Reference: https://github.com/tensorflow/tensorflow/blob/
  // 2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
  return AT_DISPATCH_FLOATING_TYPES(
      at::typeMetaToScalarType(grad.dtype()),
      "logcumsumexp_backward",
      [grad, self, result, dim]() {
        // Stand-in for log(0): the lowest finite value of the dispatched
        // scalar type, used wherever grad has the wrong sign for the branch.
        auto grad_min = at::empty_like(grad);
        grad_min.fill_(std::numeric_limits<scalar_t>::lowest());
        auto log_grad_positive = at::where(grad > 0, grad.log(), grad_min);
        auto log_grad_negative = at::where(grad < 0, (-grad).log(), grad_min);

        // logcumsumexp accumulated from the end of `dim` towards the start.
        auto reverse_logcumsumexp = [dim](auto x) {
          return at::flip(at::logcumsumexp(at::flip(x, {dim}), dim), {dim});
        };

        auto output_pos =
            (reverse_logcumsumexp(log_grad_positive - result) + self).exp();
        auto output_neg =
            (reverse_logcumsumexp(log_grad_negative - result) + self).exp();

        return output_pos - output_neg;
      });
}

// Backward of unbind along `dim`: stacks the per-slice gradients back
// together. Undefined gradients are replaced by zeros that take their shape
// and options from the first defined gradient in `grads`.
Tensor unbind_backward(const variable_list& grads, int64_t dim) {
  IntArrayRef sizes;
  at::TensorOptions o;
  // Iterate by const reference: copying each Variable only to inspect it
  // costs a reference-count round trip per element.
  for (const auto& v : grads) {
    if (v.defined()) {
      sizes = v.sizes();
      o = static_cast<Tensor>(v).options();
      break;
    }
  }
  auto grads_tensors = fmap(grads, [&](const Variable& v) {
    return (
        v.defined() ? static_cast<Tensor>(v) : at::zeros({}, o).expand(sizes));
  });
  return at::stack(grads_tensors, dim);
}

// Inserts a size-1 dimension into `self` at every position where `sizes`
// holds a 1, walking the positions in ascending order.
Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) {
  Tensor restored = self;
  const int64_t rank = sizes.size();
  for (int64_t axis = 0; axis < rank; ++axis) {
    if (sizes[axis] != 1) {
      continue;
    }
    restored = restored.unsqueeze(axis);
  }
  return restored;
}

// Inserts a size-1 dimension into `self` at `dim` when `sizes` holds a 1
// there; otherwise returns `self` unchanged.
Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) {
  dim = at::maybe_wrap_dim(dim, sizes.size());
  // In NumPy it's not an error to squeeze a scalar, but we still need to
  // avoid unsqueezing in the backward, hence the emptiness check.
  const bool dim_was_squeezed = sizes.size() > 0 && sizes[dim] == 1;
  if (!dim_was_squeezed) {
    return self;
  }
  return self.unsqueeze(dim);
}

// Backward of cat along `dim`: slices `grad` back into one piece per input,
// using the recorded input shapes. Returns undefined tensors when `grad` is
// undefined.
std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, int64_t dim) {
  std::vector<Tensor> grad_inputs(sizes.size());
  if (!grad.defined()) {
    return grad_inputs;
  }
  dim = at::legacy_cat_wrap_dim(dim, sizes);

  int64_t offset = 0;  // start of the current input's slice along `dim`
  for (size_t i = 0; i < sizes.size(); ++i) {
    const auto& shape = sizes[i];
    // An input of shape [0] receives an empty gradient and takes up no room
    // in `grad`.
    if (shape == std::vector<int64_t>({0})) {
      grad_inputs[i] = at::zeros({0}, grad.options());
      continue;
    }
    const auto length = shape[dim];
    grad_inputs[i] = grad.narrow(dim, offset, length);
    offset += length;
  }
  return grad_inputs;
}

// Backward of clamp: passes `grad` through where `self` lies within the
// given bounds. The gradient is not defined exactly on min and max; those
// points are included in the mask, i.e. they get the subgradient 1.
Tensor clamp_backward(const Tensor & grad, const Tensor &self, const optional<Scalar> & min, const optional<Scalar> & max) {
  if (!min && !max) {
    return grad;
  }
  Tensor in_range;
  if (max && min) {
    in_range = (self >= *min) * (self <= *max);
  } else if (min) {
    in_range = (self >= *min);
  } else {
    in_range = (self <= *max);
  }
  return grad * in_range.type_as(grad);
}

// Used by load_derivatives.py as the replacement for tensor.strides() calls
// that appear in derivative formulas.
//
// If `input` has requires_grad set, its strides are returned, or an error is
// raised when the tensor is sparse. If requires_grad is not set, an empty
// array is returned because there will be no backward pass.
//
// Limitations:
//  - Only supports the case where `input` is the tensor whose single
//    derivative is being calculated.
//  - Does not support `self` derivatives for inplace functions.
//
// Args:
//  input              Tensor to call .strides() on
//  input_name         Name of `input` tensor, from derivative formula
at::IntArrayRef strides_or_error(const Tensor & input, c10::string_view const & input_name) {
  // TODO: Ideally, this function would never be called if requires_grad is
  // not set. Once codegen is updated to avoid the call, we can remove this
  // check.
  if (!input.requires_grad()) {
    return IntArrayRef({});
  }
  TORCH_CHECK(
    !input.is_sparse(),
    "The backward pass for this operation requires the '", input_name,
    "' tensor to be strided, but a sparse tensor was given instead. ",
    "Please either use a strided tensor or set requires_grad=False for '",
    input_name, "'");
  return input.strides();
}

// Gradient of mm w.r.t. its first operand, scaled by `alpha`. The two
// branches compute the same product; they differ in whether it is formed
// transposed so the result comes back in column order.
Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, at::IntArrayRef mat1_sizes, at::IntArrayRef mat1_strides, const Scalar & alpha) {
  // If the input was column-major, return grad in column order for efficiency.
  const bool input_column_major =
      mat1_strides[0] == 1 && mat1_strides[1] == mat1_sizes[0];
  if (!input_column_major) {
    return maybe_multiply(grad.mm(mat2.t().conj()), alpha);
  }
  return maybe_multiply(mat2.conj().mm(grad.t()).t(), alpha);
}

// Gradient of mm w.r.t. its second operand, scaled by `alpha`. `sizes` and
// `strides` describe the second operand; `mat1` may be sparse.
Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef sizes, IntArrayRef strides, const Scalar & alpha) {
  // If the input was column-major, return grad in column order for efficiency.
  const bool input_column_major = strides[0] == 1 && strides[1] == sizes[0];
  if (!input_column_major) {
    return maybe_multiply(mat1.t().conj().mm(grad), alpha);
  }
  if (!mat1.is_sparse()) {
    return maybe_multiply(grad.t().mm(mat1.conj()).t(), alpha);
  }

  // Since mm(dense, sparse) doesn't exist, pass a transposed output matrix
  // to the underlying "addmm" function directly.
  int64_t out_rows = mat1.size(1);
  int64_t out_cols = grad.size(1);
  Tensor zeros_input = at::zeros({}, grad.options()).expand({out_rows, out_cols}, true);
  Tensor out = at::empty({out_cols, out_rows}, grad.options()).t();
  at::addmm_out(out, zeros_input, mat1.t(), grad, alpha, 1);
  return out;
}

// Gradient of _sparse_addmm w.r.t. its sparse operand: the dense product
// alpha * grad @ dense^T, masked by the coalesced sparse operand.
Tensor _sparse_addmm_sparse_backward(const Tensor& grad, const Tensor& sparse_, const Tensor& dense, const Scalar& alpha) {
  AT_ASSERT(sparse_.is_sparse());
  auto mask = sparse_.coalesce();
  Tensor dense_grad = maybe_multiply(grad.mm(dense.t()), alpha);
  return dense_grad.sparse_mask(mask);
}

// Returns a new sparse tensor that holds the values of `input` at the indices
// of `mask`; the values of `mask` itself are ignored. `input` and `mask` are
// sparse matrices (the original note describes them as sparse tensors with
// sparse_dim=2 and dense_dim=2) and must have the same shape.
// The output must carry the same indices as `mask`, so a clone of them is
// used. Picking the matching values needs a device-specific (CPU/CUDA)
// helper, `_sparse_matrix_mask_helper`, driven by the mask's indices.
Tensor _sparse_matrix_mask(const Tensor& input, const Tensor& mask){
  Tensor output = at::empty_like(mask);
  Tensor mask_indices = mask._indices().clone();

  // With an empty mask there is nothing to look up: use zero values.
  Tensor r_values = (mask._nnz() == 0)
      ? at::zeros_like(mask._values())
      : _sparse_matrix_mask_helper(input, mask_indices.contiguous());

  at::sparse::get_sparse_impl(output)->set_indices_and_values_unsafe(mask_indices, r_values);
  return output;
}

// Backward of sparse matrix-matrix matmul (SPMM).
//
// Starting from the dense definition
//     c = a @ b
//     a_grad = c_grad @ b^T
//     b_grad = a^T @ c_grad
// the sparse version masks each dense-style gradient with the sparsity
// pattern of the operand it belongs to:
//     grad_order == 0:  a_grad = sparse_matrix_mask(c_grad @ b^T, mask=a)
//     grad_order == 1:  b_grad = sparse_matrix_mask(a^T @ c_grad, mask=b)
// Any other `grad_order` raises an error.
Tensor sparse_sparse_matmul_backward(
    const Tensor& grad,
    const Tensor& a,
    const Tensor& b,
    int64_t grad_order) {
  TORCH_CHECK(
      grad_order == 0 || grad_order == 1,
      ": grad_order not in [0, 1] at sparse_sparse_matmul_backward function");
  if (grad_order != 0) {
    auto b_grad = _sparse_sparse_matmul(a.t(), grad);
    return _sparse_matrix_mask(b_grad.coalesce(), b.coalesce());
  }
  auto a_grad = _sparse_sparse_matmul(grad, b.t());
  return _sparse_matrix_mask(a_grad.coalesce(), a.coalesce());
}

// Backward of renorm. Sub-tensors along `dim` whose p-norm is below
// `maxnorm` pass `grad` through unchanged; the others get the gradient of
// the rescaling.
Tensor renorm_backward(const Tensor & grad, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) {
  // renorm computes the norm over all dimensions except `dim`, so each tensor
  // is viewed as a 2-D (size(dim), rest) matrix and converted back afterwards.
  // TODO: simplify this when we add support for norm over multiple dimensions.
  auto transposed_sizes = self.transpose(dim, 0).sizes().vec();
  auto to_rows = [&](const Tensor & t) {
    return t.transpose(dim, 0).contiguous().view({t.size(dim), -1});
  };
  auto from_rows = [&](const Tensor & t) {
    return t.contiguous().view(transposed_sizes).transpose(dim, 0);
  };

  auto self_flat = to_rows(self);
  auto grad_flat = to_rows(grad);
  auto norm_flat = self_flat.norm(p, 1, true);

  auto grad_output = (self_flat * grad_flat).sum(1, true);
  auto nb = norm_backward(grad_output, self_flat, p, norm_flat, 1, true);
  auto invnorm = (norm_flat + 1e-7).reciprocal();
  auto grad_norm = from_rows(maxnorm * invnorm * (grad_flat - invnorm * nb));
  auto norm = from_rows(norm_flat.expand_as(self_flat));

  // TODO: remove the detach once comparison ops no longer require grad
  auto below_max = Variable(norm < maxnorm).detach();
  return at::where(below_max, grad, grad_norm);
}

// Backward of Tensor::repeat.
//
// Strategy: sum away the leading dimensions that repeat() prepended, then do
// a single reshape + sum that folds every repeated dimension back onto the
// input extent.
//
// For each input dimension with repeat != 1 the gradient dimension
//   [..., repeat * input_size, ...]
// is reshaped to
//   [..., repeat, input_size, ...]
// and summed over the `repeat` axis. The `repeat` axis must come FIRST:
// the output is laid out as `repeat` consecutive copies of the input, so
// copy k of input element j lives at index k * input_size + j.
//
// Example, input of size (3, 2) with repeat(2, 1), output grad of size (6, 2):
//
//   output grad (6, 2)      reshape (2, 3, 2)        sum over dim 0 -> (3, 2)
//   [[g1_0, g1_1],          [[[g1_0, g1_1],          [[g1_0+g2_0, g1_1+g2_1],
//    [g1_2, g1_3],            [g1_2, g1_3],           [g1_2+g2_2, g1_3+g2_3],
//    [g1_4, g1_5],            [g1_4, g1_5]],          [g1_4+g2_4, g1_5+g2_5]]
//    [g2_0, g2_1],           [[g2_0, g2_1],
//    [g2_2, g2_3],            [g2_2, g2_3],
//    [g2_4, g2_5]]            [g2_4, g2_5]]]
//
// Reshaping to [..., input_size, repeat, ...] and summing over the second
// axis would instead add neighbouring rows of the same copy (g1_0+g1_2, ...),
// which does not line up with the input.
//
// Example of the bookkeeping for several dimensions:
//   Input Size         (2,    3,    4,    5)
//   repeat             [4,    1,    9,    3]
//   output/grad Size   (8,    3,    36,   15)
//   grad_size          [4, 2,    3, 9, 4, 3, 5]
//   sum_dims           [0,          3,    5]
Tensor repeat_backward(Tensor grad, IntArrayRef repeats, IntArrayRef input_shape) {
  // If any repeat count is 0 the result is all zeros of the input shape.
  const bool has_zero_repeat =
      std::find(repeats.cbegin(), repeats.cend(), 0) != repeats.cend();
  if (has_zero_repeat) {
    return at::zeros(input_shape, grad.options());
  }

  const auto input_dims = input_shape.size();
  // Number of leading gradient dimensions beyond the input's rank.
  int64_t num_unsqueezed = grad.dim() - input_dims;
  for (int64_t remaining = num_unsqueezed; remaining > 0; --remaining) {
    grad = grad.sum(0, false);
  }

  at::DimVector grad_size;
  at::DimVector sum_dims;
  for (size_t dim = 0; dim < input_dims; ++dim) {
    const int64_t repeat = repeats[dim + num_unsqueezed];
    if (repeat != 1) {
      // The `repeat` axis is about to be appended at index grad_size.size().
      sum_dims.push_back(grad_size.size());
      grad_size.push_back(repeat);
    }
    // The input extent is always kept; a repeat of 1 needs no extra axis.
    grad_size.push_back(input_shape[dim]);
  }

  // With every repeat equal to 1 there is nothing to fold. Summing over an
  // empty sum_dims would reduce the whole gradient to a scalar, so skip it.
  if (sum_dims.empty()) {
    return grad;
  }
  grad = grad.reshape(grad_size);
  return grad.sum(sum_dims);
}

// Backward of fused dropout: scales `grad` by `mask / p1m`, where p1m == 1 - p.
Tensor _fused_dropout_backward(Tensor grad, Tensor mask, double p1m) {
  const double scale = 1. / p1m;
  if (!grad.requires_grad()) {
    // No double backward is needed, so the fused kernel can be used directly.
    return at::_masked_scale(grad, mask, scale);
  }
  // Double backward is required: express the result with autograd-friendly ops.
  return grad * (mask.type_as(grad) * scale);
}

// Splits `grad` evenly among all elements of `input` that equal `value`
// (NaN is treated as equal to NaN). Every other position receives zero.
Tensor evenly_distribute_backward(Tensor grad, const Tensor & input, const Tensor & value) {
  if (!input.is_cuda()) {
    // Off the CUDA path we read `value` on the host to pick the mask.
    const bool value_is_nan = value.isnan().item<bool>();
    auto mask = value_is_nan ? input.isnan() : input == value;
    auto share = grad / mask.sum();
    return grad.new_zeros(input.sizes(), input.options()).masked_fill_(mask, share);
  }
  // CUDA path: build the mask purely with tensor ops (no .item() call).
  // NOTE(review): presumably this avoids a device-to-host sync; confirm.
  auto both_nan = input.isnan().logical_and_(value.isnan());
  auto mask = (input == value).logical_or_(both_nan);
  return mask * (grad / mask.sum());
}

// Backward of var() reduced over all elements:
//   d var / d self = 2 * (self - mean(self)) / (numel - unbiased)
Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
  auto centered = self - self.mean();
  // `unbiased` is a bool and is subtracted as 0 or 1.
  const auto denom = self.numel() - unbiased;
  return (2.0 / denom) * grad * centered;
}

// Backward of var() reduced over the dimensions in `dim`.
// A 0-dim `self` is forwarded to the all-reduce overload.
Tensor var_backward(Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
  if (self.dim() == 0) {
    return var_backward(grad, self, unbiased);
  }
  const bool reduced_dims_dropped = !keepdim;
  if (reduced_dims_dropped && self.dim() > 1) {
    // Re-insert the reduced dimensions so that `grad` broadcasts against `self`.
    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
  }
  auto centered = self - self.mean(dim, true);
  const auto denom = _safe_size(self.sizes(), dim) - unbiased;
  return (2.0 / denom) * grad * centered;
}

// Backward of std() reduced over all elements, expressed through var_backward:
// the incoming gradient is divided by 2 * std, and set to zero where std == 0.
Tensor std_backward(const Tensor & result, const Tensor & grad, const Tensor & self, bool unbiased) {
  auto grad_var = grad / (result * 2);
  grad_var.masked_fill_(result == 0, 0);
  return var_backward(grad_var, self, unbiased);
}

// Backward of std() reduced over `dim`, expressed through var_backward:
// the incoming gradient is divided by 2 * std, and set to zero where std == 0.
Tensor std_backward(const Tensor & result, Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
  auto grad_var = grad / (result * 2);
  grad_var.masked_fill_(result == 0, 0);
  return var_backward(grad_var, self, dim, unbiased, keepdim);
}

// Backward of mean() over `dim`: the sum backward divided by the number of
// elements that were reduced (as reported by _safe_size).
Tensor mean_backward(Tensor grad, const IntArrayRef sizes, IntArrayRef dim, bool keepdim) {
  auto expanded = sum_backward(grad, sizes, dim, keepdim);
  const auto count = _safe_size(sizes, dim);
  return expanded / count;
}

// Backward of mean() over all elements: expand `grad` to `sizes` and divide
// by `numel`.
Tensor mean_backward(Tensor grad, const IntArrayRef sizes, int numel) {
  auto expanded = grad.expand(sizes);
  return expanded / numel;
}

// Combined backward for var_mean / std_mean reduced over `dim`.
//   grads[0]: gradient w.r.t. the var (or std, when `is_std`) output
//   grads[1]: gradient w.r.t. the mean output
// Either may be undefined; the result is undefined if both are.
// `r1` is the var/std result (used only for std); `r2` is unused here.
Tensor var_std_mean_backward(const variable_list& grads, const Tensor & self, const Tensor & r1, const Tensor & r2, IntArrayRef dim, bool unbiased, bool keepdim, bool is_std) {
  const auto& grad_spread = grads[0];
  const auto& grad_mean = grads[1];
  Tensor total;
  if (grad_spread.defined()) {
    if (is_std) {
      total = std_backward(r1, grad_spread, self, dim, unbiased, keepdim);
    } else {
      total = var_backward(grad_spread, self, dim, unbiased, keepdim);
    }
  }
  if (grad_mean.defined()) {
    Tensor from_mean = mean_backward(grad_mean, self.sizes(), dim, keepdim);
    total = grad_spread.defined() ? total + from_mean : from_mean;
  }
  return total;
}

// Combined backward for var_mean / std_mean reduced over all elements.
//   grads[0]: gradient w.r.t. the var (or std, when `is_std`) output
//   grads[1]: gradient w.r.t. the mean output
// Either may be undefined; the result is undefined if both are.
// `r1` is the var/std result (used only for std); `r2` is unused here.
Tensor var_std_mean_backward(const variable_list& grads, const Tensor & self, const Tensor & r1, const Tensor & r2, bool unbiased, bool is_std) {
  const auto& grad_spread = grads[0];
  const auto& grad_mean = grads[1];
  Tensor total;
  if (grad_spread.defined()) {
    if (is_std) {
      total = std_backward(r1, grad_spread, self, unbiased);
    } else {
      total = var_backward(grad_spread, self, unbiased);
    }
  }
  if (grad_mean.defined()) {
    Tensor from_mean = mean_backward(grad_mean, self.sizes(), self.numel());
    total = grad_spread.defined() ? total + from_mean : from_mean;
  }
  return total;
}

// Backward of masked_scatter w.r.t. the source tensor, whose shape is `sizes`.
// The gradient entries selected by `mask` come first; any remaining source
// positions are padded with zeros.
Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArrayRef sizes) {
  // Total number of elements described by `sizes`.
  const int64_t numel = std::accumulate(
      sizes.begin(), sizes.end(), int64_t{1}, std::multiplies<int64_t>());

  // masked_select yields a 1-d tensor holding the elements where mask is set.
  auto selected = grad.masked_select(mask);
  const auto missing = numel - selected.numel();
  if (missing <= 0) {
    return selected.view(sizes);
  }
  // Fill out the rest with zeros before reshaping back to `sizes`.
  auto padding = at::zeros({missing}, grad.options());
  return at::cat({selected, padding}, 0).view(sizes);
}

// Backward of cholesky(). `L` is the factor returned by the forward pass
// (upper-triangular when `upper` is true).
Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
  // cf. Iain Murray (2016); arXiv 1602.07527
  // The returned gradient is symmetric rather than triangular.
  // Cholesky assumes a symmetric input, which is a subspace of R^{n x n}, so
  // the derivative is not well-defined for off-diagonal elements. This is
  // resolved by differentiating w.r.t. the functionally independent elements
  // (the lower triangular portion of the input) and reflecting the result
  // onto the upper triangular portion, i.e. symmetrizing the gradient.
  // The motivation: a symmetric gradient leads to stable gradient updates and
  // keeps the matrix symmetric under gradient-based updates.
  const auto adjoint = [](const Tensor & m) {
    return m.transpose(-1, -2).conj();
  };

  if (upper) {
    // Reduce the upper-triangular case to the lower-triangular one.
    L = adjoint(L);
    grad = adjoint(grad);
  }

  // L^{-1}, obtained by solving L X = I.
  auto identity = at::eye(L.size(-1), L.options());
  auto L_inv = std::get<0>(at::triangular_solve(identity, L, /*upper=*/false));

  // phi = tril(L^H grad) with its diagonal halved.
  auto phi = at::matmul(adjoint(L), grad);
  phi.tril_();
  phi.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5);

  auto unsymmetrized = at::matmul(at::matmul(adjoint(L_inv), phi), L_inv);
  // Symmetrizing the gradient
  return unsymmetrized.add(adjoint(unsymmetrized)).mul_(0.5);
}

// Backward of cholesky_inverse() w.r.t. the factor `L`.
// `inverse` is the forward result. An undefined `grad` yields zeros shaped like `L`.
Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inverse) {
  if (!grad.defined()) {
    return at::zeros({1}, L.options()).expand_as(L);
  }
  // inverse * (grad + grad^T) * inverse is shared by both branches.
  Tensor symmetrized = grad + grad.transpose(-2, -1);
  Tensor common_term = at::matmul(inverse, at::matmul(symmetrized, inverse));
  // The side on which L multiplies depends on whether L is upper or lower.
  return upper ? -at::matmul(L, common_term) : -at::matmul(common_term, L);
}

// Backward of split_with_sizes(): concatenates the per-chunk gradients along
// `dim`. `sizes` is the shape of the original (unsplit) tensor.
Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &grads,
                                 IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) {
  dim = at::maybe_wrap_dim(dim, sizes.size());

  // Some grads may be undefined (they stand for all-zero tensors). at::cat
  // cannot handle those, so materialize zeros of the matching chunk shape.
  std::vector<Tensor> pieces;
  pieces.reserve(grads.size());
  for (size_t idx = 0; idx < grads.size(); ++idx) {
    const auto& g = grads[idx];
    if (g.defined()) {
      pieces.push_back(g);
      continue;
    }
    auto chunk_shape = sizes.vec();
    chunk_shape[dim] = split_sizes[idx];
    pieces.push_back(at::zeros(chunk_shape, options));
  }

  return at::cat(pieces, dim);
}

// Backward of split() with a uniform `split_size`: rebuilds the per-chunk sizes
// and defers to split_with_sizes_backward.
Tensor split_backward(const std::vector<torch::autograd::Variable> &grads,
                      int64_t split_size, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) {
  dim = at::maybe_wrap_dim(dim, sizes.size());
  const int64_t num_splits = grads.size();
  const int64_t dim_size = sizes[dim];

  // Every chunk has `split_size` elements, except that the last one is
  // shortened by however much the chunks would overshoot `dim_size`.
  std::vector<int64_t> split_sizes(num_splits, split_size);
  const int64_t overshoot = split_size * num_splits - dim_size;
  split_sizes[num_splits - 1] = split_size - overshoot;

  return split_with_sizes_backward(grads, split_sizes, dim, sizes, options);
}

// Double backward of max pooling: gathers `grad` at the positions recorded in
// `indices`. The trailing `dim` dimensions are flattened into one before the
// gather; the result has the shape of `indices`.
Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int dim) {
  AT_ASSERT(indices.dim() >= dim);
  // Keep the leading dimensions and collapse the last `dim` of them into -1.
  const auto num_leading = indices.dim() - dim;
  auto flat_shape = indices.sizes().slice(0, num_leading).vec();
  flat_shape.push_back(-1);

  // `grad` is made contiguous in the memory format suggested by `indices`
  // before it is viewed with the flattened shape.
  const auto memory_format = indices.suggest_memory_format();
  auto flat_indices = indices.view(flat_shape);
  auto flat_grad = grad.contiguous(memory_format).view(flat_shape);
  return flat_grad.gather(-1, flat_indices).view(indices.sizes());
}

// Double backward of GLU w.r.t. `input`.
// `input` is split along `dim` into halves (a, b); `grad` (the gradient flowing
// into glu_backward's output) is split the same way into (ggI_a, ggI_b).
Tensor glu_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & input, int64_t dim) {
  const auto& gO = grad_output;
  const auto half = input.size(dim) / 2;

  auto a = input.narrow(dim, 0, half);
  auto b = input.narrow(dim, half, half);
  auto sig_b = b.sigmoid();
  auto one_minus_sig_b = 1 - sig_b;
  // sigmoid'(b) = sigmoid(b) * (1 - sigmoid(b))
  auto dsig_b = sig_b * one_minus_sig_b;

  auto ggI_a = grad.narrow(dim, 0, half);
  auto ggI_b = grad.narrow(dim, half, half);
  auto ggI_b_times_a = ggI_b * a;

  // sigmoid''(b) = sigmoid'(b) * (1 - sigmoid(b)) - sigmoid(b) * sigmoid'(b)
  auto d2sig_b = dsig_b * one_minus_sig_b - sig_b * dsig_b;

  auto gI_a = ggI_b * gO * dsig_b;
  auto gI_b = ggI_b_times_a * gO * d2sig_b + ggI_a * gO * dsig_b;
  return at::cat({gI_a, gI_b}, dim);
}

// Double backward of GLU w.r.t. `grad_output`: multiplies `grad` by
// glu_backward evaluated with an all-ones grad_output, then adds the two
// halves of the product along `dim`.
Tensor glu_double_backward_grad_output(const Tensor & grad, const Tensor & input, int64_t dim) {
  if (dim < 0) {
    dim += input.dim();
  }
  // Shape of the GLU output: `input` with `dim` halved.
  auto half_shape = input.sizes().vec();
  half_shape[dim] /= 2;
  const auto half = half_shape[dim];

  auto ones = at::ones(half_shape, input.options());
  auto tmp = grad * glu_backward(ones, input, dim);
  return tmp.narrow(dim, 0, half) + tmp.narrow(dim, half, half);
}

// SiLU backward written with differentiable ops only:
//   d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
Tensor infinitely_differentiable_silu_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  const Tensor sig = input.sigmoid();
  const Tensor one_minus_sig = 1.0 - sig;
  return grad_output * sig * (1.0 + input * one_minus_sig);
}

// Logit backward written with differentiable ops only:
//   d/dx logit(x) = 1 / (x * (1 - x))
// With `eps`, the gradient is zero outside [eps, 1 - eps].
// Without `eps`, the gradient is NaN outside [0, 1].
Tensor infinitely_differentiable_logit_backward(
    const Tensor& grad,
    const Tensor& self,
    c10::optional<double> eps) {
  const Tensor local_grad = grad / (self * (1.0 - self));
  if (!eps) {
    auto out_of_domain = at::empty({}, self.options())
                             .fill_(std::numeric_limits<double>::quiet_NaN());
    auto in_domain = at::logical_and(self >= 0.0, self <= 1.0);
    return at::where(in_domain, local_grad, out_of_domain);
  }
  const double lo = eps.value();
  const double hi = 1.0 - lo;
  auto in_range = at::logical_and(self >= lo, self <= hi);
  return at::where(in_range, local_grad, at::zeros({}, self.options()));
}

// Double backward of kl_div w.r.t. `grad_output`: runs kl_div_backward without
// reduction on `grad`, then applies the requested reduction to the result.
Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, bool log_target) {
  auto unreduced = kl_div_backward(grad, input, target, at::Reduction::None, log_target);
  switch (reduction) {
    case at::Reduction::Mean:
      return unreduced.mean();
    case at::Reduction::Sum:
      return unreduced.sum();
    default:
      return unreduced;
  }
}

// Compute derivatives of kl_div with respect to `target`.
//   log_target == false: grad_output * (log(target) + 1 - self), zero where target == 0
//   log_target == true:  grad_output * (target + 1 - self) * exp(target)
// The result is divided by target.numel() for mean reduction.
Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction, bool log_target) {
  Tensor grad_target;
  if (log_target) {
    auto local = target.add(1).sub_(self).mul_(target.exp());
    grad_target = grad_output.mul(local);
  } else {
    auto local = target.log().add_(1).sub_(self);
    grad_target = grad_output.mul(local).masked_fill_(target == 0, 0.);
  }

  const bool is_mean = (reduction == at::Reduction::Mean);
  if (is_mean) {
    grad_target.div_(target.numel());
  }

  return grad_target;
}

// Gradient of binary_cross_entropy_with_logits with respect to `target`.
//   with pos_weight:    (log(1 - sigmoid(self)) - pos_weight * log(sigmoid(self))) * grad_output
//   without pos_weight: -self * grad_output
// The result is then multiplied by `weight` (if defined) and divided by
// target.numel() for mean reduction.
Tensor binary_cross_entropy_with_logits_target_backward(const Tensor& grad_output, const Tensor& self, const Tensor& target, const c10::optional<Tensor>& weight, const c10::optional<Tensor>& pos_weight, int64_t reduction) {
  Tensor grad_target;
  if (!isDefined(pos_weight)) {
    grad_target = self.mul(-grad_output);
  } else {
    auto log_one_minus_sig = (1. - self.sigmoid()).log_();
    auto weighted_log_sig = pos_weight->mul(self.sigmoid().log_());
    grad_target = log_one_minus_sig.sub_(weighted_log_sig).mul_(grad_output);
  }

  if (isDefined(weight)) {
    grad_target.mul_(*weight);
  }

  const bool is_mean = (reduction == at::Reduction::Mean);
  if (is_mean) {
    grad_target.div_(target.numel());
  }

  return grad_target;
}

// Double backward of log_sigmoid: grad * (sigmoid(input) - 1) * sigmoid(input).
Tensor log_sigmoid_double_backward(const Tensor & grad, const Tensor & input) {
  const auto sig = input.sigmoid();
  const auto sig_minus_one = sig - 1;
  return grad * sig_minus_one * sig;
}

// Double backward of softmax along `dim`.
// `output` is the softmax result, `grad_output` (gO) the gradient that was fed
// to the first backward, and `grad` (ggI) the gradient w.r.t. its result.
Tensor softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
  const auto& gO = grad_output;
  const auto& ggI = grad;

  // Shared sub-expressions (all sums keep the reduced dimension).
  auto ggI_y = ggI * output;
  auto sum_ggI_y = ggI_y.sum(dim, true);
  auto sum_ggI_y_times_y = sum_ggI_y * output;
  auto sum_gO_y = (gO * output).sum(dim, true);

  // gI calculation, as four terms
  auto term0 = ggI_y * (gO - sum_gO_y);
  auto inner = (ggI_y * gO).sum(dim, true);
  inner.sub_(sum_gO_y * sum_ggI_y);
  auto term1 = output * inner;
  auto term2 = sum_ggI_y_times_y * gO;
  auto term3 = sum_ggI_y_times_y * sum_gO_y;
  return term0 - term1 - term2 + term3;
}

// Double backward of log_softmax along `dim`. `output` is the log_softmax
// result, so exp(output) recovers the softmax probabilities.
Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
  auto probs = output.exp();
  auto gO_sum = grad_output.sum(dim, true);
  auto centered = (grad * probs).sum(dim, true) - grad;
  return probs * gO_sum * centered;
}

// NOTE: [How to write vmap-compatible backward formulas]
//
// See NOTE: [vmap-incompatible in-place operations] for what it means for an
// in-place operation to be incompatible with vmap.
//
// If an in-place operation used in a backward formula is vmap-incompatible,
// then as developers we have the following options:
//
// - If the in-place operation directly followed the creation of a tensor with
//   a factory function like at::zeros(...), we should replace the factory with a
//   corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call
//   propagates the batch dims to the resulting tensor.
//   For example:
//     Before: at::zeros(input.sizes(), grad.options()).copy_(grad)
//     After:  grad.new_zeros(input.sizes()).copy_(grad)
//
// - If the in-place operation followed some sequence of operations and we
//   want to be able to vmap over the backward formula as-is (this is
//   usually the case for simple (<15loc) backward formulas), then use
//   inplaceIsVmapCompatible to guard the operation. For example:
//             c = a * b
//     Before: c.mul_(grad)
//     After:  c = at::inplaceIsVmapCompatible(c, grad) ? c.mul_(grad) : c * grad
//
// - If we don't want to vmap directly over the backward formula (e.g., if the
//   backward formula is too complicated or has a lot of vmap-incompatible
//   operations), then register the backward formula as an operator and
//   eventually write a batching rule for it.

// Double backward of binary_cross_entropy with respect to `input`.
//
// With l(x, t) = -[t * log(x) + (1 - t) * log(1 - x)] the second derivative is
//   d^2 l / dx^2 = (x^2 - 2 * x * t + t) / (x^2 * (1 - x)^2),
// where `eps` is added to both denominator factors to avoid dividing by zero.
//
// `grad` is the gradient flowing into binary_cross_entropy_backward's output
// and `grad_output` is the gradient that was fed to the first backward.
//
// The result always has the shape of `input`: the reduction only affects the
// shape of `grad_output` (which broadcasts), plus the 1 / numel factor for
// mean reduction. In particular the result must NOT be summed for
// Reduction::Sum, otherwise the gradient collapses to a scalar.
Tensor binary_cross_entropy_double_backward(const Tensor & grad_output, const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional<Tensor>& weight, int64_t reduction) {
  auto eps = 1e-12;
  auto inp_pl_eps = input + eps;
  auto one_m_inp_pl_eps = 1 - input + eps;
  // gradient wrt input
  auto gI = (input * input - 2 * input * target + target) / (inp_pl_eps.pow(2) * one_m_inp_pl_eps.pow(2));
  // See NOTE: [How to write vmap-compatible backward formulas]
  if (at::inplaceIsVmapCompatible(gI, grad)) {
    gI *= (grad * grad_output);
  } else {
    gI = gI * (grad * grad_output);
  }

  if (isDefined(weight)) {
    gI *= *weight;
  }
  if (reduction == at::Reduction::Mean) {
    return gI / input.numel();
  }
  // Reduction::None and Reduction::Sum: per-element gradient, no rescaling.
  return gI;
}

// Double backward of binary_cross_entropy with respect to `grad_output`.
//
// The per-element term is grad * dl/dx with
//   dl/dx = (x - t) / ((x + eps) * (1 - x + eps)),
// optionally multiplied by `weight`.
//
// The result has the shape of `grad_output`:
//   Reduction::None: per-element tensor
//   Reduction::Sum:  sum over all elements
//   Reduction::Mean: sum over all elements divided by input.numel()
// For both reduced modes `grad_output` is a scalar, so the per-element terms
// are summed explicitly here.
Tensor binary_cross_entropy_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional<Tensor>& weight, int64_t reduction) {
  auto eps = 1e-12;
  // gradient wrt grad_output
  auto ggO = (input - target) / ((input + eps) * (1 - input + eps));
  // See NOTE: [How to write vmap-compatible backward formulas]
  if (at::inplaceIsVmapCompatible(ggO, grad)) {
    ggO *= grad;
  } else {
    ggO = ggO * grad;
  }

  if (isDefined(weight)) {
    ggO *= *weight;
  }
  if (reduction == at::Reduction::Mean) {
    return ggO.sum() / input.numel();
  } else if (reduction == at::Reduction::Sum) {
    return ggO.sum();
  }
  return ggO;
}

// Double backward of l1_loss w.r.t. `grad_output`: runs l1_loss_backward
// without reduction on `grad`, then applies the requested reduction.
Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
  auto unreduced = l1_loss_backward(grad, input, target, at::Reduction::None);
  switch (reduction) {
    case at::Reduction::Mean:
      return unreduced.mean();
    case at::Reduction::Sum:
      return unreduced.sum();
    default:
      return unreduced;
  }
}

// Double backward of smooth_l1_loss w.r.t. `input`: grad / beta where
// |input - target| < beta and zero elsewhere, divided by input.numel() for
// mean reduction.
Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, double beta) {
  // special case to protect against a divide-by-zero.
  if (beta == 0) {
    return at::zeros(grad.sizes(), grad.options());
  }
  auto abs_diff = (input - target).abs();
  // 1 where the loss is in its quadratic regime, 0 where it is linear.
  auto in_quadratic_zone = (abs_diff < beta).type_as(grad);
  auto grad_input = grad * in_quadratic_zone / beta;
  const bool is_mean = (reduction == at::Reduction::Mean);
  if (is_mean) {
    grad_input /= input.numel();
  }
  return grad_input;
}

// Double backward of smooth_l1_loss w.r.t. `grad_output`.
// Without reduction this is smooth_l1_loss_backward applied to `grad`.
// With a reduction, the backward is evaluated with an all-ones grad_output
// and its product with `grad` is summed.
Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction, double beta) {
  const bool is_reduced = (reduction != at::Reduction::None);
  if (is_reduced) {
    auto per_element = smooth_l1_loss_backward(ones_like(grad_output), input, target, reduction, beta);
    return (per_element * grad).sum();
  }
  return smooth_l1_loss_backward(grad, input, target, reduction, beta);
}

// Double backward of mse_loss w.r.t. `input`: 2 * grad, divided by
// input.numel() for mean reduction.
Tensor mse_loss_double_backward(const Tensor & grad, const Tensor & input, int64_t reduction) {
  auto doubled = 2 * grad;
  const bool is_mean = (reduction == at::Reduction::Mean);
  if (is_mean) {
    doubled /= input.numel();
  }
  return doubled;
}

// Double backward of mse_loss w.r.t. `grad_output`.
// Without reduction this is mse_loss_backward applied to `grad`.
// With a reduction, the backward is evaluated with an all-ones grad_output
// and its product with `grad` is summed.
Tensor mse_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
  const bool is_reduced = (reduction != at::Reduction::None);
  if (is_reduced) {
    auto per_element = mse_loss_backward(ones_like(grad_output), input, target, reduction);
    return (per_element * grad).sum();
  }
  return mse_loss_backward(grad, input, target, reduction);
}

// Double backward of soft_margin_loss w.r.t. `input`:
//   grad * target^2 * z / (1 + z)^2   with z = exp(-input * target),
// divided by input.numel() for mean reduction.
Tensor soft_margin_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
  auto z = (input * -target).exp();
  auto one_plus_z = z + 1;
  auto target_sq = target * target;
  auto denom = one_plus_z * one_plus_z;
  auto grad_input = grad * target_sq * z / denom;
  const bool is_mean = (reduction == at::Reduction::Mean);
  if (is_mean) {
    grad_input /= input.numel();
  }
  return grad_input;
}

// Double backward of soft_margin_loss w.r.t. `grad_output`.
// Without reduction this is soft_margin_loss_backward applied to `grad`.
// With a reduction, the backward is evaluated with an all-ones grad_output
// and its product with `grad` is summed.
Tensor soft_margin_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
  const bool is_reduced = (reduction != at::Reduction::None);
  if (is_reduced) {
    auto per_element = soft_margin_loss_backward(ones_like(grad_output), input, target, reduction);
    return (per_element * grad).sum();
  }
  return soft_margin_loss_backward(grad, input, target, reduction);
}

// Double backward of softplus: sigmoid_backward evaluated at
// sigmoid(input * beta), zeroed where input * beta >= threshold, times beta.
Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, Scalar beta, Scalar threshold) {
  auto scaled = (input * beta);
  auto below_threshold = (scaled < threshold).type_as(grad);
  return sigmoid_backward(grad, scaled.sigmoid()) * below_threshold * beta;
}


// NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
//
// `storage_offset` is ignored for simplicity in this note. If you just want the
// full algorithm without explanation, scroll down to bottom of this note.
//
// Implementing the backward of as_strided is tricky because you have to deal
// with mappings that map one memory location to multiple indices, i.e., the
// output tensor has multiple indices pointing to **overlapping** memory
// addresses. This can happen in all sorts of weird cases. For example,
//
//   x = torch.randn(15)
//   x.as_strided([3, 3], [1, 0])  # "expand" case
//   x.as_strided([3, 3], [2, 1])  # "size too large" case
//   x.as_strided([3, 2], [3, 6])  # res[2, 0] points to 2*3 + 0*6 = 6
//                                 # res[0, 1] points to 0*3 + 1*6 = 6
//
// Here is the general strategy we apply in implementing as_strided backward:
//   0. ??? (optimization step. we will talk about this later)
//   1. Create some underlying flattened tensor as if it is the base tensor
//      representing the contiguous memory storage for both input and output.
//   2. Use the output geometry to scatter (or index_add) the gradients into
//      this storage tensor.
//   3. ??? (fix for input tensor with overlapping memory. we will talk about
//           this later)
//   4. Return the as_strided view of the storage tensor using input geometry.
//
// In step (2), if the output tensor doesn't have overlapping memory, we can
// safely scatter (`storage.as_strided(output_geometry).copy_(grad)`);
// otherwise, we must use `index_add` as gradients at different indices may need
// to be summed to a single location.
//
// For example, in this case:
//
//   x = torch.randn(3)
//   y = x.as_strided([3, 3], [1, 0])  # "expand" case
//                                     # size   [ 3, 3]
//                                     # stride [ 1, 0]
//   y.backward()  # step (1): contiguous storage tensor `s` of size 3, which
//                             is large enough to be used as underlying storage
//                             for `x` and `y`.
//                               s = [ 0, 0, 0]
//                 # step (2): since `y` has overlapping memory, index_add grad
//                             into `s` basing on `y`'s geometry, i.e.,
//                             s[i * y.stride(0) + j * y.stride(1)] += gy[i, j].
//                               s = [ 3, 3, 3]
//                 # step (4): as_strided view `s` using `x`'s geometry
//                               s = [ 3, 3, 3]
//                               grad_input = s.as_strided(x.size(), x.stride())
//                                          = s.as_strided([3], [1])
//                                          = [ 3, 3, 3]
//
// This is exactly what we would get if using `expand`. However, here the input
// tensor doesn't have overlapping memory. If it does, we must add an extra step
// before (4). Considering this case:
//
//   t = torch.randn(3)
//   x = t.expand(3, 3)            # input with overlapping memory
//                                 # size   [3, 3]
//                                 # stride [0, 1]
//   y = x.as_strided([1], [1])    # contiguous output
//                                 # size   [1]
//                                 # stride [1]
//   y.backward()  # step (1): contiguous storage tensor `s` of size 3, which
//                             is large enough to be used as underlying storage
//                             for `x` and `y`.
//                               s = [ 0, 0, 0]
//                 # step (2): scatter grad into `s` basing on `y`'s geometry
//                               s = [ 1, 0, 0]
//                 # step (4): as_strided view `s` using `x`'s geometry
//                               s = [ 1, 0, 0]
//                               grad_input = s.as_strided([3, 3], [0, 1])
//                                          = s.as_strided([3, 3], [0, 1])
//                                          = [[ 1, 0, 0],
//                                             [ 1, 0, 0],
//                                             [ 1, 0, 0]]
// Is this result correct?
//
// `x.as_strided([1], [1])` call is obviously equivalent with
// `x[(0,) * x.dim()].view(1)` for any `x`. But autograd through the second
// gives gradient `[ [ 1, 0, 0], [ 0, 0, 0], [ 0, 0, 0]]`. For this specific
// case, indexing `x` at any index in first column is also equivalent, and
// yields a gradient of shape `[3 x 3]` containing eight 0's and one 1. There is
// an `x.size(1)`-times difference between these gradients computed from other
// PyTorch ops and the gradient we got from as_strided.
//
// You might conclude that the gradients from as_strided is wrong. However,
// let's first see why they are actually reasonable. Consider the pointwise
// perturbations by `delta` anywhere in the first column of `x`. It will lead to
// a `delta` change in the same memory location, and then `y` will change by
// `delta`. So one can say the gradient should be exactly 1 at the first column,
// as given by our above procedure.
//
// In the above computation of numerical gradients, they only match the
// analytical results because strides and memory locations are considered in the
// forward pass, i.e., this op (including both forward and backward) is
// layout-aware.
//
// However, in PyTorch, most (probably all) other ops (forward and backward) are
// layout-agnostic. E.g.,
//
//   t = torch.randn(1)
//   x = t.expand(2)
//   y = x.sum()
//   y.backward()
//
// Layout-agnostic autograd (as it is currently in PyTorch) will give you
//
//   gy = 1
//   gx = [ 1, 1]  # SumBackward:    torch.ones_like(x)
//   gt = [ 2]     # ExpandBackward: gx.sum()
//
// Note that `gx = [ 1, 1]`. However, if you perturb any value in `x` by `delta`
// (the other will also change by `delta`), `y` will change by `2 * delta`. So
// the gradients, if strides are taken into consideration, should be 2.
//
// Layout-aware autograd should give you
//
//   gy = 1
//   gx = [ 2, 2]  # Because the backward considers the fact that the input `x`
//                 # is already expanded.
//   gt = [ 2]     # Layout-aware backward of expand is just a slicing because
//                 # the previous backward should have already taken care of
//                 # strides and made sure that gradients are the same along the
//                 # expanded dimension.
//
// As shown above, these two types are not compatible. Therefore, we must either
// make as_strided layout-agnostic, or make all other ops layout-aware.
//
// It is difficult to support layout-aware autograd (at least in the current
// codebase structure), because it would mean
//   1. storing tensor geometries of every input tensor for backward
//   2. depending on input geometry, the gradient computed from backward changes
//   3. ideally enforcing gradient of T to always have same strides as T
// (although these two methods only differ when it comes to overlapping memory)
//
// Therefore, we must formulate `as_strided` in a layout-agnostic way, i.e.,
// giving the same output regardless of the input layout. We consider
// `input.stride()` as a separate independent fixed argument `input_stride`.
// Then, `as_strided(input, size, stride)` can be thought of as:
//   1. "Scatter" each value of `input` into a "storage" using storage location
//      computed from the value's index in `input`, `input.size()` and
//      `input_stride`, but if N values end up in the same location, the value
//      is average of those N values (they will be the same value anyways).
//
//      Formal description:
//        Denote the set of all input indices that pointing to the same storage
//        location `storage[n]` as `S(n)`, i.e.,
//
//            S(n) = { index : <index, input_stride> == n, index is valid given input.size() },
//
//        where `<x, y>` is the dot product between `x` and `y`.
//
//        Then, the process is:
//
//            storage[n] = Avg { S(n) }
//
//        Note that all values in `S(n)` are the same (they point to the same
//        memory location anyways), so this step doesn't change anything, but
//        effectively avoids having the dependency on the layout of `input`.
//        I.e., the result holds fixed regardless of the layout of `input`, as
//        long as `input_stride` is fixed.
//
//      NOTE: for forward pass, we can equivalently simply select any one of
//            `S(n)` as `storage[n]`. However, considering this as an average
//            operation makes backward easier (so all values in set
//            `{ grad_input[i] : i in S(n) }` are the same, and it can use the
//            same geometry as input).
//   2. As usual, return the as_strided view of `storage` using required output
//      `size` and `stride`.
//
// To backward through this layout-agnostic version, we simply add the following
// step:
//   .... (scatter gradients into the storage tensor using output geometry)
//   3. For all storage location n, `storage[n] /= |S(n)|`.
//   .... (return as_strided view of the storage tensor using input geometry)
//
// Finally, we note that these general operations are expensive, so we apply the
// following optimizations:
//   Add step (0): For all output dimension `d` with output stride 0, sum the
//                 gradients along dimension `d` (don't keepdim), and remove
//                 dimension `d` from output size and stride.
//                 (An optimization for "expand" cases so we may avoid step (3))
//  Only apply step (3) when input tensor has overlapping memory.
//
// FULL ALGORITHM:
//   0. For all output dimension `d` with output stride 0, sum the gradients
//       along dimension `d` (don't keepdim), and remove dimension `d` from
//       output size and stride.
//   1. Create some underlying flattened tensor as if it is the base tensor
//      representing the contiguous memory storage for both input and output.
//   2. Use the output geometry to scatter (or index_add) the gradients into
//      this storage tensor `storage`.
//   3. If input tensor has overlapping memory,
//      For all storage location `i`, `storage[i] /= N(i)`, where `N(i)` is the
//      number of indices in input geometry pointing to the same storage
//      location `i` (i.e., `|S(i)|` in equations above).
//   4. Return the as_strided view of the storage tensor using input geometry.
//
// See NOTE [ Detecting Memory Overlap Within A Strided Tensor ] on how to
// roughly detect overlapping memory.


// NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
//
// Checking memory overlap within a strided tensor is the special case of
// detecting memory overlap of two strided tensors, where the two tensors start
// at the same memory address. The latter is HARD (see #8212).
//
// But even this special case isn't simple. This note describes a check for an
// even more constrained simple case where we can be certain that there is no
// overlap.
//
// The checking algorithm can be described as:
//   0. Return [ pass check ] if any dimension has size 0
//   1. Ignore all dimensions that have size 1
//   2. If no remaining dimensions, return [ pass check ]
//   3. Sort the remaining dimensions according to the strides decreasingly
//   4. Check that for each dimension k,
//
//           stride[k] > \sum_{ i > k } (size[i] - 1) * stride[i]
//
//      That is equivalent to, after reordering the dimensions so strides are
//      in decreasing order, checking that stride of each dimension is larger
//      than the maximum memory offset in a slice at that dimension.
//
// Obviously this check passes for contiguous tensors ( the dimensions will be
// already sorted with LHS = stride[0] = \prod size[i] being exactly 1 larger
// than RHS ). Similarly, the check passes for tensors contiguous in all but
// the last dimension, and LHS = stride[0] = stride[-1] * \prod size[i] being
// exactly stride[-1] larger than RHS. (*)
//
// We will show that these view operations, including all our view operations
// *except for* general as_strided and unfold, also preserve this invariant:
//
//  alias:      Obviously preserves
//
//  expand:     All changed dimensions are removed in step (1)
//
//  view:       Consider the input dimensions as grouped into consecutive
//              dimension "blocks", where dimensions are contiguous in each
//              one. view only works when the output dimensions can also be
//              grouped into the same consecutive blocks of same ordering.
//
//              NB: this means that the number of elements and stride of the
//                  last dimension in each block is the same in input and
//                  output. (**)
//
//              Notation:
//                Consider a single such block B,
//                    ... B_prev[-1]], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ B_next[0], ...
//                                start--^^^^                  ^^^^^^^^^^^^--end
//                Each B[i] denotes a dimension index such that B[i] = B[0] + i.
//
//              We first show that if a tensor (i.e., input) satisfies the
//              invariant, after sorting, the dimensions within each block
//              still remain consecutive. (***)
//
//                After removing dimensions of size 1, the dimensions within a
//                block is already sorted by strides in descending order. So
//                sorting all dimensions will not change the relative ordering
//                among them.
//
//                Assume that some block B is not consecutive after sorting,
//                i.e., there exists a dimension d between B[0] and B[-1] in
//                sorted order.
//
//                By (*), we know that
//                       stride[B[0]]
//                    =  \sum_{i > 0}   (size[B[i]] - 1) * stride[B[i]] + stride[B[-1]]
//                    <  \sum_{i > 0}   (size[B[i]] - 1) * stride[B[i]] + stride[d]
//                    <= \sum_{i > 0}   (size[B[i]] - 1) * stride[B[i]] + (size[d] - 1) * stride[d]
//                    <= \sum_{j > B[0]} (size[j]   - 1) * stride[j],
//
//                where the first <   comes from sorting and
//                      the second <= comes from the fact that dimension d
//                                               exists after step (1) and
//                                               thus must have size greater
//                                               than 1
//                      the third  <= comes from the fact that each term in
//                                               the sum is non-negative
//
//                Then we have a contradiction as the invariant must not be
//                satisfied at B[0]. So the original proposition is true.
//
//              Now that we established the above claim (***), we consider the
//              view operation as first sorting the dimensions (i.e., blocks),
//              applying the original view (since it only cares about
//              dimensions being consecutive and contiguous within each block),
//              and then undoing the sort.
//
//              Consider a single block B in the output,
//                  ... ], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ ...
//                    start--^^^^                  ^^^^^^^^^^^^--end
//
//              By (*), we know that for all i
//                  stride[B[i]] = stride[B[-1]] +
//                                   \sum_{j=i+1}^{k} (size[B[j]] - 1) * stride[B[j]]
//
//              Then the invariant is obviously satisfied at every dimension
//              in this block if it is satisfied at dimension B[-1]. It only
//              remains to show that it is satisfied at the last dimension in
//              each block.
//
//              Since the same blocks are present in both input and output
//              with the same ordering, we will abuse the notation in the
//              following statements.
//
//              By (*), we know that the following holds for both input and
//              output, for any block B:
//                    \sum_{i > B[-1]} (size[i] - 1) * stride[i]
//                  = \sum_{block B' after B} (\prod_{j in B'} size[j] - 1) * stride[B'[-1]]
//                  = \sum_{block B' after B} (numel(B') - 1) * stride[B'[-1]].
//                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^|^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//              By (**), we know that, this quantity in the above equation
//              remains the same in input and output. So both
//                  \sum_{i > B[-1]} (size[i] - 1) * stride[i]
//              and
//                  stride[B[-1]]
//              are the same in input and output.
//
//              These two quantities are exactly the LHS and RHS of the
//              invariant inequality. Since by assumption the invariant is
//              satisfied in input at B[-1], it is also satisfied in output at
//              B[-1]. This concludes the proof.
//
//  squeeze:    Special case of view
//
//  unsqueeze:  Special case of view
//
//  slice:      Consider slicing dimension i with step = k >= 1.
//
//              Let stride' and size' be the output strides and sizes. We have
//
//                  stride'[i] = k * stride[i]
//                  size'[i] <= floor(size[i] / k)
//
//              If size'[i] = 1, invariant is obviously satisfied as we are
//              just removing a dimension (after step (1)).
//
//              Assume size'[i] > 1.
//
//              By assumption, the invariant is satisfied at every dimension
//              in input.
//
//              For any dimension j, if stride[j] > stride[i], we have
//                  stride'[j] =  stride[j]
//                             >  (size[i] - 1) * stride[i]
//                             =  (size[i] / k * k - 1) * k * stride[i] / k
//                             =  (size[i] / k - 1 / k) * stride'[i]
//                             >= (size'[i]    - 1 / k) * stride'[i]
//                             >= stride'[i].
//
//              If stride[j] < stride[i], we have
//                  stride'[j] = stride[j] < stride[i] <= stride'[i].
//
//              So the sorting order remains unchanged after slice.
//
//              Since
//                     (size'[i] - 1) * stride'[i]
//                  =  (floor(size[i] / k) - 1) * k * stride[i]
//                  <= (size[i] / k - 1) * k * stride[i]
//                  =  (size[i] - k) * stride[i]
//                  <= (size[i] - 1) * stride[i],
//              the term from this dimension i in the invariant inequality at
//              other dimensions can only decrease after slice. So the
//              invariant is preserved.
//
//  narrow:     Special case of slice
//
//  select:     narrow + squeeze
//
//  permute:    Sorting makes permutation of dimensions irrelevant
//
//  transpose:  Sorting makes swapping dimensions irrelevant
//
//  diagonal:   Effectively merging two dimensions i and j into a new
//              dimension k s.t.
//                  stride'[k] =  stride[i] + stride[j]
//                  size'[k]   <= min(size[i], size[j]),
//              where stride and size are on the input, and stride' and size'
//              are on the output.
//
//              Assuming that size[i] > 1 and size[j] > 1. If any has size 1,
//              then this is unsqueeze on that dimension.
//
//              WLOG, say stride[i] <= stride[j].
//
//              Each dimension d in input with stride[d] > stride[j] has
//                  stride'[d] =  stride[d]
//                             >  (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j]
//                             >= stride[i] + stride[j]
//                             =  stride'[k].
//              So, considering the sorted dimensions, this is effectively
//              removing i, and replacing j with k.
//
//              For dimensions d with stride[i] < stride[d] < stride[j], the
//              term from dimension i is removed in the invariant inequality.
//              For dimensions d with stride[d] > stride[j], we have
//                     (size'[k] - 1) * stride'[k]
//                  <= (min(size[i], size[j]) - 1) * (stride[i] + stride[j])
//                  <= (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j],
//              so the term from i and j in the invariant can only decrease.
//
//              So this is generally relaxing the constraint, and thus it
//              preserves it.

// Implements steps (2)~(4) of the algorithm described in
// NOTE [ Detecting Memory Overlap Within A Strided Tensor ].
// Helper for as_strided_backward.
//
// Visits the dimensions in order of increasing stride while tracking the
// largest offset reachable through the dimensions visited so far. Returns true
// as soon as a dimension's stride does not exceed that offset, and false if no
// such dimension is found (or if there are no dimensions at all).
static inline bool _maybe_overlapping_memory(IntArrayRef sizes, IntArrayRef strides) {
  const std::size_t ndim = sizes.size();
  if (ndim == 0) {
    return false;
  }

  // Permutation of the dimension indices, ordered by ascending stride.
  std::vector<std::size_t> by_stride(ndim);
  std::iota(by_stride.begin(), by_stride.end(), 0);
  std::sort(
      by_stride.begin(),
      by_stride.end(),
      [&strides](std::size_t lhs, std::size_t rhs) { return strides[lhs] < strides[rhs]; });

  int64_t max_reachable = 0;
  for (const std::size_t d : by_stride) {
    const int64_t step = strides[d];
    if (step <= max_reachable) {
      return true;
    }
    // Largest offset obtainable by also walking to the end of dimension d.
    max_reachable += step * (sizes[d] - 1);
  }
  return false;
}

// Smallest storage length able to hold a tensor with the given sizes, strides
// and storage_offset. If any dimension has size 0 the tensor has no elements
// and storage_offset itself is returned.
// Helper for as_strided_backward.
static inline int64_t _min_storage_size(IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) {
  // One past the offset of the first element; grows by the span of each dim.
  int64_t required = storage_offset + 1;
  const int64_t ndim = sizes.size();
  for (int64_t d = 0; d != ndim; ++d) {
    const auto extent = sizes[d];
    if (extent == 0) {
      return storage_offset;
    }
    required += (extent - 1) * strides[d];
  }
  return required;
}

// See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for explanation
//
// Backward of as_strided.
//   grad            - incoming gradient; it is indexed with the same dimension
//                     indices as `sizes`/`strides`. Taken by value because it
//                     is squeezed/summed in place of the local copy below.
//   input_geometry  - sizes, strides and storage offset of the input.
//   sizes, strides  - geometry that was requested for the output.
//   storage_offset_ - output storage offset; when absent, the input's storage
//                     offset is used.
// Returns a zero tensor of the input's sizes if either geometry contains a
// size-0 dimension. Otherwise returns an as_strided view (with the input's
// sizes and strides) into a freshly created 1-D buffer that the gradient was
// scattered into.
Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_) {
  // For output geometry,
  //   check for size 0 dimensions,
  //   skip size 1 dimensions,
  //   reduce grad on expanded dims (stride=0, size>1)
  // Step (0)     for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
  // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on output geometry
  auto storage_offset = storage_offset_.value_or(input_geometry.storage_offset());
  auto odim = grad.dim();
  std::vector<int64_t> out_sizes_, out_strides_;
  out_sizes_.reserve(odim);
  out_strides_.reserve(odim);
  // Iterate from the last dimension to the first so that removing dimension i
  // from grad (squeeze / sum) does not shift the indices still to be visited.
  for (int64_t i = odim - 1; i >= 0; i--) {
    auto size_i = sizes[i];
    auto stride_i = strides[i];
    if (size_i == 0) {
      return at::zeros(input_geometry.sizes(), grad.options());
    } else if (size_i == 1) {
      grad = grad.squeeze(i);
    } else if (stride_i == 0) {
      grad = grad.sum(i, false);
    } else {
      // Prepending keeps the kept dimensions in their original order.
      out_sizes_.insert(out_sizes_.begin(), size_i);
      out_strides_.insert(out_strides_.begin(), stride_i);
    }
  }
  // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on output geometry
  auto out_maybe_overlap = _maybe_overlapping_memory(out_sizes_, out_strides_);

  // For input geometry,
  //   check for size 0 dimensions,
  //   skip size 1 dimensions,
  // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on input geometry
  auto idim = input_geometry.dim();
  IntArrayRef inp_sizes = input_geometry.sizes(), inp_strides = input_geometry.strides();
  std::vector<int64_t> inp_sizes_, inp_strides_;
  inp_sizes_.reserve(idim);
  inp_strides_.reserve(idim);
  for (int64_t i = idim - 1; i >= 0; i--) {
    auto size_i = inp_sizes[i];
    auto stride_i = inp_strides[i];
    if (size_i == 0) {
      return at::zeros(input_geometry.sizes(), grad.options());
    } else if (size_i != 1) {
      inp_sizes_.insert(inp_sizes_.begin(), size_i);
      inp_strides_.insert(inp_strides_.begin(), stride_i);
    }
  }
  // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on input geometry
  auto inp_maybe_overlap = _maybe_overlapping_memory(inp_sizes_, inp_strides_);


  // Rest of this function implements
  // Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
  // TODO: Raise if not all output values are visible in input geometry.
  //       Technically speaking, if you treat those values as constants, not
  //       raising is fine, and mathematically correct. However, these values
  //       really are contained in some base tensor, and by treating them as
  //       constants we are ignoring this tight dependency. Therefore, it is
  //       more sensible to raise here.

  // Step (1): create underlying tensor as "storage"
  // Both offsets are rebased onto the smaller of the two, so the buffer only
  // has to cover the region starting at that shared offset.
  auto shared_offset = std::min(input_geometry.storage_offset(), storage_offset);
  auto inp_effective_offset = input_geometry.storage_offset() - shared_offset;
  auto out_effective_offset = storage_offset - shared_offset;
  auto base_size = std::max(
    _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset),
    _min_storage_size(out_sizes_, out_strides_, out_effective_offset)
  );
  auto storage = grad.new_zeros({base_size});

  // prepare indices tensor if we will do index_add_ later
  c10::optional<at::Tensor> flatten_full_indices;
  if (inp_maybe_overlap || out_maybe_overlap) {
    flatten_full_indices = at::arange(0, base_size, grad.options().dtype(at::kLong));
  }

  // Step (2): use output geometry to scatter gradients into storage
  if (out_maybe_overlap) {
    // Viewing the index range through the output geometry yields, for every
    // output element, the buffer position it maps to; index_add_ accumulates
    // gradients of elements that map to the same position.
    auto out_indices = flatten_full_indices->as_strided(out_sizes_, out_strides_, out_effective_offset);
    storage.index_add_(0, out_indices.reshape(-1), grad.reshape(-1));
  } else {
    // assume that new tensors have 0 storage offset
    storage.as_strided(out_sizes_, out_strides_, out_effective_offset).copy_(grad);
  }

  // Step (3): if input tensor has overlapping memory, divide scattered gradient
  //           at storage[i] by the number of times i shows up in input geometry
  if (inp_maybe_overlap) {
    auto count = at::zeros_like(storage, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    auto inp_indices = flatten_full_indices->as_strided(inp_sizes_, inp_strides_, inp_effective_offset).reshape(-1);
    count.index_add_(0, inp_indices, at::ones({1}, grad.options()).expand_as(inp_indices));
    storage.div_(count); // this will give nan outside visible range
  }
  // Step (4): return as_strided view of the storage tensor with input geometry
  // Note: the unfiltered input sizes/strides (including size-1 dims) are used here.
  return storage.as_strided(inp_sizes, inp_strides, inp_effective_offset);
}

std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask) {
  if (!grad.defined()) {
    return std::tuple<Tensor, Tensor>{Tensor(), Tensor()};
  }
  auto recip = (self * self + other * other).reciprocal();
  return std::tuple<Tensor,Tensor>{
            output_mask[0] ? grad * other * recip : Tensor(),
            output_mask[1] ? grad * -self * recip : Tensor() };
}

// TODO: Seriously consider writing the derivative formulas for
// each output separately; there is not all that much sharing
// of computation going on here.
std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
    const Tensor & grad_grad_input,
    const Tensor & grad_grad_weight,
    const Tensor & grad_out,
    const Tensor & input_,
    const Tensor & weight_) {

  if (!(grad_grad_input.defined() || grad_grad_weight.defined() || grad_out.defined())) {
    return std::tuple<Tensor, Tensor, Tensor>(Tensor(), Tensor(), Tensor());
  }
    auto input = input_.contiguous();
    auto weight = weight_.contiguous();

  // Zero-fill undefined grads (TODO: do this more efficiently)
  auto ggI = grad_grad_input.defined() ? grad_grad_input.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto ggW = grad_grad_weight.defined() ? grad_grad_weight.contiguous() : at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto gO = grad_out.defined() ? grad_out.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  auto positive_mask = (input > 0).type_as(ggI);
  auto nonpositive_mask = (input <= 0).type_as(ggW);

  // Explanation: Let input be i, weight be w, grad_output be gO.
  // f(i, w) = i      if i > 0
  //         = w * i  if i <= 0
  // gI = df/di * gO  = gO      if i > 0    gW = df/dw * gO = 0       if i > 0
  //                  = gO * w  if i <= 0                   = gO * i  if i <= 0
  // The rest is taking derivatives of these wrt i, w, gO and summing/expanding properly.

  if (weight.numel() == 1) {
      // from PReLU.forward: num_parameters == 0 is used indicate that a
      // single weight is shared among all input channels.

      // this is a little tricky because PReLU currently doesn't take a shape so the weight may be
      // 1-d when the input is a scalar (and there isn't a good Parameter API for that anyway until Variable
      // and tensor are merged).  So, use weight and ggW as 0-dim in this case.
      bool scalar_input_1d_weight = (positive_mask.dim() == 0 && weight.dim() == 1);
      auto weight_maybe_squeeze = scalar_input_1d_weight ? weight.squeeze() : weight;
      auto ggW_maybe_squeeze = scalar_input_1d_weight ? ggW.squeeze() : ggW;

      auto mask = positive_mask + nonpositive_mask * weight_maybe_squeeze.expand_as(input);
      auto ggO = ggI * mask + ggW_maybe_squeeze.expand_as(gO) * (nonpositive_mask * input);
      return std::tuple<Tensor, Tensor, Tensor>(
                ggO,
                ggW_maybe_squeeze.expand_as(gO) * gO * nonpositive_mask,
                (ggI * gO * nonpositive_mask).sum().expand_as(weight)
          );
  } else {
      // Expand ggW to match size of ggI; a simple expand doesn't work because
      // ggW is the size of the input channel (dim==1 unless there is only 1 dimension).  For example,
      // let ggI be size (3,4,5,6,7) and ggW be size (4).  Then we unsqueeze ggW to be size (4,1,1,1)
      // so the expand succeeds.
      auto dims_to_unsqueeze = std::max<int64_t>(input.dim() - 2, 0);
      auto ggW_expanded = ggW;
      for (int64_t i = 0; i < dims_to_unsqueeze; i++) {
          ggW_expanded = ggW_expanded.unsqueeze(1);
      }
      ggW_expanded = ggW_expanded.expand_as(ggI);

      auto gI = ggW_expanded * gO * nonpositive_mask;

      auto gW = ggI * gO * nonpositive_mask;
      if (input.dim() > 1) {
          gW = gW.sum(0);
      }
      while (gW.dim() > 1) {
          gW = gW.sum(1);
      }

      Tensor ggO;
      if (gO.requires_grad()) {
          // expand weight as input as in ggW/ggI above
          auto weight_expanded = weight;
          for (int64_t i = 0; i < dims_to_unsqueeze; i++) {
              weight_expanded = weight_expanded.unsqueeze(1);
          }
          weight_expanded = weight_expanded.expand_as(input);

          auto mask = positive_mask + nonpositive_mask * weight_expanded;
          ggO = ggI * mask + ggW_expanded * nonpositive_mask * input;
      }
      return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW};
  }
}

// https://j-towns.github.io/papers/svd-derivative.pdf
//
// This makes no assumption on the signs of sigma.
//
// Backward of SVD. Combines the incoming gradients
//   grads[0] (w.r.t. U), grads[1] (w.r.t. sigma), grads[2] (w.r.t. V)
// into the gradient w.r.t. `self`.
//   some                - when false, only the first k = sigma.size(-1)
//                         columns of U and V (and of their gradients) are used.
//   compute_uv          - must be true, otherwise an error is raised.
//   raw_u, sigma, raw_v - the outputs of the forward computation.
// Returns u_term + sigma_term + v_term, plus an additional term when `self` is
// complex and the gradient w.r.t. U is defined. If neither the U nor the V
// gradient is defined, only the sigma term is returned.
Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
          bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) {
  TORCH_CHECK(compute_uv,
           "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ",
           "and hence we cannot compute backward. Please use torch.svd(compute_uv=True)");

  auto m = self.size(-2);
  auto n = self.size(-1);
  auto k = sigma.size(-1);
  auto gsigma = grads[1];

  auto u = raw_u;
  // Currently torch.svd for complex dtypes returns the conjugate of V,
  // while the backward formula is derived with just V (without the conjugation)
  // therefore here we need to conjugate the V output of SVD and grads[2].
  // Once https://github.com/pytorch/pytorch/issues/45821 is resolved
  // extra .conj(), that are marked below in the code, shall be removed.
  auto v = raw_v.conj();  // TODO: remove .conj()
  auto gu = grads[0];
  auto gv = grads[2].conj();  // TODO: remove .conj()

  if (!some) {
    // We ignore the free subspace here because possible base vectors cancel
    // each other, e.g., both -v and +v are valid base for a dimension.
    // Don't assume behavior of any particular implementation of svd.
    u = raw_u.narrow(-1, 0, k);
    v = raw_v.narrow(-1, 0, k).conj();  // TODO: remove .conj()
    if (gu.defined()) {
      gu = gu.narrow(-1, 0, k);
    }
    if (gv.defined()) {
      gv = gv.narrow(-1, 0, k);
    }
  }
  // Conjugate transpose of v.
  auto vh = v.conj().transpose(-2, -1);

  Tensor sigma_term;
  if (gsigma.defined()) {
    // Cast so that the products below are computed in self's dtype.
    gsigma = gsigma.to(self.dtype());
    // computes u @ diag(gsigma) @ vh
    sigma_term = at::matmul(u * gsigma.unsqueeze(-2), vh);
  } else {
    sigma_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  // in case that there are no gu and gv, we can avoid the series of kernel
  // calls below
  if (!gv.defined() && !gu.defined()) {
    return sigma_term;
  }

  // Conjugate transpose of u.
  auto uh = u.conj().transpose(-2, -1);
  auto im = at::eye(m, self.options());
  auto in = at::eye(n, self.options());
  auto sigma_mat = sigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).to(self.dtype());
  auto sigma_mat_inv = sigma.pow(-1).diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).to(self.dtype());
  auto sigma_sq = sigma.pow(2);
  // F[..., i, j] = sigma[j]^2 - sigma[i]^2 (before the inversion below).
  auto F = sigma_sq.unsqueeze(-2) - sigma_sq.unsqueeze(-1);
  // The following two lines invert values of F, and fills the diagonal with 0s.
  // Notice that F currently has 0s on diagonal. So we fill diagonal with +inf
  // first to prevent nan from appearing in backward of this function.
  F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
  F = F.pow(-1);

  Tensor u_term, v_term;

  // Contribution of the gradient w.r.t. U.
  if (gu.defined()) {
    auto guh = gu.conj().transpose(-2, -1);
    u_term = at::matmul(u, at::matmul(F.mul(at::matmul(uh, gu) - at::matmul(guh, u)), sigma_mat));
    // Extra projection term, applied only when self has more rows than
    // there are singular values.
    if (m > k) {
      u_term = u_term + at::matmul(im - at::matmul(u, uh), at::matmul(gu, sigma_mat_inv));
    }
    u_term = at::matmul(u_term, vh);
  } else {
    u_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }

  // Contribution of the gradient w.r.t. V.
  if (gv.defined()) {
    auto gvh = gv.conj().transpose(-2, -1);
    v_term = at::matmul(sigma_mat, at::matmul(F.mul(at::matmul(vh, gv) - at::matmul(gvh, v)), vh));
    // Extra projection term, applied only when self has more columns than
    // there are singular values.
    if (n > k) {
      v_term = v_term + at::matmul(sigma_mat_inv, at::matmul(gvh, in - at::matmul(v, vh)));
    }
    v_term = at::matmul(u, v_term);
  } else {
    v_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }

  // for complex-valued input there is an additional term
  // https://giggleliu.github.io/2019/04/02/einsumbp.html
  // https://arxiv.org/abs/1909.02659
  if (self.is_complex() && gu.defined()) {
    // computes L = Identity.mul(uh @ gu)
    Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1).diag_embed(0, -2, -1);
    // L minus its conjugate transpose.
    L = L - L.conj().transpose(-2, -1);
    Tensor imag_term = 0.5 * at::matmul(at::matmul(at::matmul(u, L), sigma_mat_inv), vh);
    return u_term + sigma_term + v_term + imag_term;
  }

  return u_term + sigma_term + v_term;
}

// "An extended collection of matrix derivative results for forward and reverse mode algorithmic differentiation"
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
//
// Backward of eig. grads[0] is the gradient w.r.t. the eigenvalues (lambda)
// and grads[1] the gradient w.r.t. the eigenvectors (v). Raises unless
// `eigenvectors` is true and column 1 of lambda (the imaginary parts) is
// close to zero. The result is undefined when neither gradient is defined.
Tensor eig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                    bool eigenvectors, const Tensor& lambda, const Tensor& v) {
  // This gradient only works for real eigenvalues at the moment.
  TORCH_CHECK(eigenvectors,
           "eig_backward: Setting eigenvectors to false in torch.eig doesn't compute eigenvectors ",
           "and hence we cannot compute backward. Please use torch.eig(eigenvectors=True)");
  const auto zero_imag = at::zeros({1}, lambda.options());
  const auto imag_part = lambda.slice(/*dim=*/-1, /*start=*/1, /*end=*/2);
  TORCH_CHECK(
      at::allclose(imag_part, zero_imag),
      "eig_backward: Backward calculation does not support complex eigenvalues at the moment.");

  const auto& grad_lambda = grads[0];
  const auto& grad_v = grads[1];
  const auto v_t = v.transpose(-2, -1);

  Tensor grad_self;

  // contribution from the eigenvectors
  if (grad_v.defined()) {
    const auto real_lambda = lambda.slice(/*dim=*/-1, /*start=*/0, /*end=*/1);

    // Reciprocal pairwise eigenvalue differences; the diagonal is set to
    // +inf first so that it becomes 0 after the reciprocal.
    auto inv_diff = real_lambda.transpose(-2,-1) - real_lambda;
    inv_diff.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
    inv_diff.pow_(-1.0);

    const auto grad_v_ortho = grad_v - at::sum(grad_v * v, /*dim=*/-2, /*keepdim=*/true) * v;
    const auto inner = inv_diff * at::matmul(v_t, grad_v_ortho);
    const auto rhs = at::matmul(inner, v_t);

    grad_self = std::get<0>(at::solve(rhs, v_t));
  }

  // contribution from eigenvalues
  if (grad_lambda.defined()) {
    const auto scaled_v_t = grad_lambda.slice(/*dim=*/-1, /*start=*/0, /*end=*/1) * v_t;
    const auto rhs = at::matmul(v, scaled_v_t);
    const auto gram = at::matmul(v, v_t);
    Tensor lambda_term = std::get<0>(at::solve(rhs, gram));
    grad_self = grad_self.defined() ? grad_self.add(lambda_term) : lambda_term;
  }
  return grad_self;
}

// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
//
// Backward of symeig. grads[0] is the gradient w.r.t. the eigenvalues (lambda)
// and grads[1] the gradient w.r.t. the eigenvectors (v). Raises unless
// `eigenvectors` is true.
Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                    bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v) {
  // The gradient produced here is symmetric rather than triangular.
  // symeig only operates on symmetric inputs, which form a subspace of
  // R^{n x n}, so the derivative is not well-defined for off-diagonal
  // elements. This is resolved by taking the gradient of the functionally
  // independent elements of the matrix (i.e., the lower triangular portion of
  // the input) and reflecting it onto the upper triangular portion, thereby
  // symmetrizing the gradient of the symeig operation. A symmetric gradient
  // leads to stable gradient updates and retains symmetry of the updated
  // matrix if it were updated by a gradient based algorithm.
  TORCH_CHECK(eigenvectors,
           "symeig_backward: Setting eigenvectors to false in torch.symeig doesn't compute eigenvectors ",
           "and hence we cannot compute backward. Please use torch.symeig(eigenvectors=True)");

  const auto& grad_lambda = grads[0];
  const auto& grad_v = grads[1];

  // Conjugate transpose of the eigenvector matrix.
  const auto v_h = v.conj().transpose(-2, -1);

  Tensor grad_self;
  if (!grad_v.defined()) {
    grad_self = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  } else {
    // Reciprocal pairwise eigenvalue differences; the diagonal is set to
    // +inf first so that it becomes 0 after the reciprocal.
    Tensor inv_gap = lambda.unsqueeze(-2) - lambda.unsqueeze(-1);
    inv_gap.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
    inv_gap.pow_(-1);
    grad_self = at::matmul(v, at::matmul(inv_gap * at::matmul(v_h, grad_v), v_h));
  }

  if (grad_lambda.defined()) {
    const auto grad_lambda_cast = grad_lambda.to(self.dtype());
    // computes v @ diag(glambda) @ vh
    Tensor lambda_term = at::matmul(v * grad_lambda_cast.unsqueeze(-2), v_h);
    if (at::inplaceIsVmapCompatible(grad_self, lambda_term)) {
      grad_self.add_(lambda_term);
    } else {
      grad_self = grad_self + lambda_term;
    }
  }

  // Symmetrize: average the gradient with its conjugate transpose.
  return grad_self.add(grad_self.conj().transpose(-2, -1)).mul_(0.5);
}

// Backward of QR decomposition.
//   grads[0], grads[1] - gradients w.r.t. Q and R; either may be undefined.
//   self               - the input matrix A, with m = self.size(-2) rows and
//                        n = self.size(-1) columns.
//   some               - an error is raised when `some` is false and m > n.
//   q, r               - the outputs of the forward computation.
// For m >= n the square/deep formula is applied directly. For m < n the input
// and R are split column-wise, the square formula is reused for the leading
// m columns and the result is concatenated with the gradient of the rest.
Tensor qr_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                   bool some, const Tensor& q, const Tensor& r){
  // Gradient for the square and deep (tall) case. Note that the parameter A is
  // not read inside the lambda.
  auto square_deep_case_backward = [](const Tensor& grad_Q,
                                      const Tensor& grad_R,
                                      const Tensor& A,
                                      const Tensor& Q,
                                      const Tensor& R) -> Tensor {
    // For square and deep (tall) case we refer:
    // Matthias Seeger, Asmus Hetzel, Zhenwen Dai, Eric Meissner, Neil D. Lawrence (2018). Auto-Differentiating Linear Algebra.
    // https://arxiv.org/abs/1710.08717 Section 4.3 LQ Decomposition (Note that LQ decomposition is the transpose of QR decomposition)
    // Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable Programming Tensor Networks.
    // https://arxiv.org/abs/1903.09650 Section 3. QR factorization
    // For derivations of complex-valued input case, see https://giggleliu.github.io/2019/04/02/einsumbp.html

    // Compute R grad_R^H
    Tensor R_term;
    if (grad_R.defined()) {
      R_term = at::matmul(R, grad_R.conj().transpose(-2, -1));
    } else {
      // R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N
      R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }

    // Compute grad_Q^H Q
    Tensor Q_term;
    if (grad_Q.defined()) {
      Q_term = at::matmul(grad_Q.conj().transpose(-2, -1), Q);
    } else {
      // Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N
      // The zero placeholder is shaped like R, i.e. like the product above.
      Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }

    Tensor M = R_term - Q_term;

    // Replace M by tril(M) + tril(M)^H and then halve its diagonal (the
    // diagonal was counted twice by the sum).
    Tensor M_tril = at::tril(M);
    M = M_tril + M_tril.conj().transpose(-2, -1);
    M.diagonal(0, -2, -1).mul_(0.5);

    Tensor rhs_term;
    if (grad_Q.defined()) {
      rhs_term = grad_Q + at::matmul(Q, M);
    } else {
      rhs_term = at::matmul(Q, M);
    }

    // We want to compute: (rhs_term @ R^{-H})
    // Note that (rhs_term @ R^{-H}) = (R^{-1} @ rhs_term^H)^H
    // Since R is upper triangular, we can do this using
    // triangular_solve(rhs_term^H, R)^H
    Tensor grad_A;
    std::tie(grad_A, std::ignore) = at::triangular_solve(
        rhs_term.conj().transpose(-2, -1),
        R,
        /*upper=*/true,
        /*transpose=*/false,
        /*unitriangular=*/false);

    return grad_A.conj().transpose(-2, -1);
  };

  auto m = self.size(-2);
  auto n = self.size(-1);

  // Complete QR (some == false) is only supported when m <= n.
  TORCH_CHECK(
      ((m <= n && (!some)) || some),
      "The derivative is not implemented when nrows > ncols and complete QR. ");

  auto grad_Q = grads[0];
  auto grad_R = grads[1];

 if (m >= n) {
    return square_deep_case_backward(grad_Q, grad_R, self, q, r);
  } else {
    // For wide (m < n) input matrices A,  partition A = [X|Y] and R = [U|V]
    // X and U are square full rank matrices. We will partition grads,
    // grad_R = [grad_U | grad_V] and grad_A = [grad_X | grad_Y].
    // To obtain grad_X we reuse the gradient formula from the square case.
    // Formulae: grad_X = square_case_grad(grad_Q_prime, grad_U, Q, U),
    // where grad_Q_prime = grad_Q + Y @ grad_V^H
    // and grad_Y = Q @ grad_V.
    // Then concatenate grads to get grad_A = [grad_X | grad_Y].

    // Y: the trailing n - m columns of the input; U: the leading m columns of r.
    auto Y = self.narrow(-1, m, n - m);
    auto U = r.narrow(-1, 0, m);
    Tensor grad_Y, grad_X, grad_V, grad_Q_prime;

    if (grad_R.defined()) {
      grad_V = grad_R.narrow(-1, m, n - m);
      // reuse grad_R to store grad_U
      grad_R = grad_R.narrow(-1, 0, m);
      // grad_Q_prime starts with the value of Y @ grad_V^H
      grad_Q_prime = at::matmul(Y, grad_V.conj().transpose(-2, -1));
    } else {
      // when grad_R is not defined then grad_V and grad_Q_prime
      // get initialized with zeros
      grad_V = at::zeros_like(Y, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
      grad_Q_prime = at::zeros_like(q, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }

    if (grad_Q.defined()) {
      // add the grad_Q term into grad_Q_prime when defined o/w is 0
      grad_Q_prime = grad_Q_prime + grad_Q;
    }
    // Calculate grad_X using the helper. Grad_R contains the grad_U value
    // (grad_R is still undefined here if it was undefined on entry).
    grad_X = square_deep_case_backward(grad_Q_prime, grad_R, self, q, U);
    grad_Y = at::matmul(q, grad_V);
    // Concatenate grad_X and grad_Y to get grad_A.
    return at::cat({grad_X, grad_Y}, -1);
  }
}

// Gradient of det(self) with respect to self.
// Matrices with a nonzero determinant use the formula derived from Jacobi's
// formula, grad_self = (grad * det) * inverse(self)^T, which can also be found at:
// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
// Matrices whose determinant compares equal to zero are routed through the SVD:
// the gradient w.r.t. the singular values is obtained from prod_backward and
// then pushed back to self with svd_backward.
Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
  // det == 0: differentiate through the singular values
  auto backward_via_svd = [](const Tensor& g, const Tensor& mat, const Tensor& d) -> Tensor {
    Tensor u, sigma, v;
    std::tie(u, sigma, v) = mat.svd();
    auto grad_sigma = prod_backward(g.unsqueeze(-1), sigma, d.unsqueeze(-1));
    return svd_backward({{}, grad_sigma, {}}, mat, true, true, u, sigma, v);
  };

  // det != 0: closed form based on the inverse
  auto backward_via_inverse = [](const Tensor& g, const Tensor& mat, const Tensor& d) -> Tensor {
    auto scale = unsqueeze_multiple(g * d, {-1, -2}, mat.dim());
    return scale * mat.inverse().transpose(-2, -1);
  };

  // Single matrix: pick the formula from the scalar determinant.
  if (self.dim() == 2) {
    const bool is_singular = det.item<double>() == 0;
    if (is_singular) {
      return backward_via_svd(grad, self, det);
    }
    return backward_via_inverse(grad, self, det);
  }

  // Batched input: avoid the index/scatter round trip when the whole batch
  // falls into one category.
  auto invertible_idx = at::where(det);
  if (invertible_idx[0].size(0) == det.numel()) {  // every determinant is nonzero
    return backward_via_inverse(grad, self, det);
  }

  auto singular_idx = at::where(det == 0);
  if (singular_idx[0].size(0) == det.numel()) {  // every determinant is zero
    return backward_via_svd(grad, self, det);
  }

  // Mixed batch: compute each category separately and scatter into the result.
  Tensor grad_self = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  Tensor invertible_part = backward_via_inverse(grad.index(invertible_idx),
                                                self.index(invertible_idx),
                                                det.index(invertible_idx));
  grad_self.index_put_(/*indices=*/invertible_idx, /*value=*/invertible_part);

  Tensor singular_part = backward_via_svd(grad.index(singular_idx),
                                          self.index(singular_idx),
                                          det.index(singular_idx));
  grad_self.index_put_(/*indices=*/singular_idx, /*value=*/singular_part);

  return grad_self;
}

// Gradient of logdet(self) with respect to self.
// Entries whose logdet is not -inf use grad * inverse(self)^T; entries whose
// logdet is -inf go through the SVD instead (logdet = \sum log(sigma)).
Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) {
  auto backward_via_svd = [](const Tensor& g, const Tensor& mat) -> Tensor {
    Tensor u, sigma, v;
    std::tie(u, sigma, v) = mat.svd();
    // logdet = \sum log(sigma), so the gradient w.r.t. sigma is g / sigma
    auto grad_sigma = g.unsqueeze(-1).div(sigma);
    return svd_backward({{}, grad_sigma, {}}, mat, true, true, u, sigma, v);
  };

  auto backward_via_inverse = [](const Tensor& g, const Tensor& mat) -> Tensor {
    auto scale = unsqueeze_multiple(g, {-1, -2}, mat.dim());
    return scale * mat.inverse().transpose(-2, -1);
  };

  // Single matrix: decide from the scalar logdet.
  if (self.dim() == 2) {
    const bool is_finite = logdet.item<double>() != -INFINITY;
    if (is_finite) {
      return backward_via_inverse(grad, self);
    }
    return backward_via_svd(grad, self);
  }

  // Batched input: handle the homogeneous batches without indexing.
  auto finite_idx = at::where(logdet != -INFINITY);
  if (finite_idx[0].size(0) == logdet.numel()) {  // no entry is -inf
    return backward_via_inverse(grad, self);
  }

  auto neginf_idx = at::where(logdet == -INFINITY);
  if (neginf_idx[0].size(0) == logdet.numel()) {  // every entry is -inf
    return backward_via_svd(grad, self);
  }

  // Mixed batch: compute each category separately and scatter into the result.
  Tensor grad_self = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  Tensor finite_part = backward_via_inverse(grad.index(finite_idx),
                                            self.index(finite_idx));
  grad_self.index_put_(/*indices=*/finite_idx, /*value=*/finite_part);

  Tensor neginf_part = backward_via_svd(grad.index(neginf_idx),
                                        self.index(neginf_idx));
  grad_self.index_put_(/*indices=*/neginf_idx, /*value=*/neginf_part);

  return grad_self;
}

// Gradient of the logabsdet output of slogdet(self) with respect to self.
// The sign output (signdet) is only used to classify matrices: a sign of zero
// selects the SVD-based formula, any other sign selects
// grad_logabsdet * inverse(self)^T.
Tensor slogdet_backward(const Tensor& grad_logabsdet,
                        const Tensor& self,
                        const Tensor& signdet, const Tensor& logabsdet) {
  auto backward_via_svd = [](const Tensor& g, const Tensor& mat) -> Tensor {
    Tensor u, sigma, v;
    std::tie(u, sigma, v) = mat.svd();
    // sigma is non-negative, so logabsdet is differentiated as \sum log(sigma),
    // whose gradient w.r.t. sigma is g / sigma
    auto grad_sigma = g.unsqueeze(-1).div(sigma);
    return svd_backward({{}, grad_sigma, {}}, mat, true, true, u, sigma, v);
  };

  auto backward_via_inverse = [](const Tensor& g, const Tensor& mat) -> Tensor {
    auto scale = unsqueeze_multiple(g, {-1, -2}, mat.dim());
    return scale * mat.inverse().transpose(-2, -1);
  };

  // Single matrix: decide from the scalar sign.
  if (self.dim() == 2) {
    const bool is_singular = signdet.item<double>() == 0;
    if (is_singular) {
      return backward_via_svd(grad_logabsdet, self);
    }
    return backward_via_inverse(grad_logabsdet, self);
  }

  // Batched input: handle the homogeneous batches without indexing.
  auto nonzero_sign_idx = at::where(signdet);
  if (nonzero_sign_idx[0].size(0) == logabsdet.numel()) {  // every sign is nonzero
    return backward_via_inverse(grad_logabsdet, self);
  }

  auto zero_sign_idx = at::where(signdet == 0);
  if (zero_sign_idx[0].size(0) == logabsdet.numel()) {  // every sign is zero
    return backward_via_svd(grad_logabsdet, self);
  }

  // Mixed batch: compute each category separately and scatter into the result.
  Tensor grad_self = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  Tensor nonsingular_part = backward_via_inverse(grad_logabsdet.index(nonzero_sign_idx),
                                                 self.index(nonzero_sign_idx));
  grad_self.index_put_(/*indices=*/nonzero_sign_idx, /*value=*/nonsingular_part);

  Tensor singular_part = backward_via_svd(grad_logabsdet.index(zero_sign_idx),
                                          self.index(zero_sign_idx));
  grad_self.index_put_(/*indices=*/zero_sign_idx, /*value=*/singular_part);

  return grad_self;
}

// Backward of triangular_solve. Returns (grad_b, grad_a).
// Reference:
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// Sec. 2.3.1 Matrix inverse product
//
// grad_x is the gradient for the solution, grad_m the gradient for the second
// output. When neither is defined both results stay undefined; otherwise a
// gradient that was not computed is returned as zeros expanded to the shape of
// the corresponding input.
std::tuple<Tensor, Tensor> triangular_solve_backward(
    const Tensor & grad_x, const Tensor & grad_m,
    const Tensor & b, const Tensor & a, const Tensor & x,
    const bool upper, const bool transpose, const bool unitriangular,
    std::array<bool, 2> output_mask) {
  Tensor grad_b, grad_a;

  const bool has_grad_x = grad_x.defined();
  const bool has_grad_m = grad_m.defined();
  if (!has_grad_x && !has_grad_m) {
    return std::tuple<Tensor, Tensor>{grad_b, grad_a};
  }

  if (has_grad_x) {
    // Solve against conj(a) with the transpose flag flipped.
    grad_b = std::get<0>(grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular));
    if (output_mask[1]) {
      if (transpose) {
        grad_a = -x.conj().matmul(grad_b.transpose(-1, -2));
      } else {
        grad_a = -grad_b.matmul(x.transpose(-1, -2).conj());
      }
      // Keep only the triangle that was read; with unitriangular the main
      // diagonal is dropped as well.
      const int diag_offset = unitriangular ? 1 : 0;
      if (upper) {
        grad_a = grad_a.triu(diag_offset);
      } else {
        grad_a = grad_a.tril(-diag_offset);
      }
    }
  }

  // Fill in whatever was not computed above with (expanded) zeros.
  if (!grad_a.defined()) {
    grad_a = at::zeros({1}, a.options()).expand_as(a);
  }
  if (!grad_b.defined()) {
    grad_b = at::zeros({1}, b.options()).expand_as(b);
  }

  // The gradient of the second output is added to grad_a as is.
  if (output_mask[1] && has_grad_m) {
    grad_a = grad_a.add(grad_m);
  }
  return std::tuple<Tensor, Tensor>{grad_b, grad_a};
}

std::tuple<Tensor, Tensor> cholesky_solve_backward(
    const Tensor& grad_x, const Tensor& self,
    const Tensor& input2, const Tensor& result, const bool upper) {
  Tensor grad_self, grad_input2;
  if (grad_x.defined()) {
    grad_self = grad_x.cholesky_solve(input2, /*upper=*/upper);

    Tensor common_term = at::matmul(grad_self, result.conj().transpose(-2, -1));
    common_term = common_term + common_term.conj().transpose(-2, -1);

    if (upper) {
      grad_input2 = -at::matmul(input2, common_term);
    } else {
      grad_input2 = -at::matmul(common_term, input2);
    }
  }
  return std::tuple<Tensor, Tensor>{grad_self, grad_input2};
}

Tensor fft_c2r_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization) {
  // Forward is a onesided C2R transform (irfft), which can be viewed as
  //    1. fill in the missing half using conjugate symmetry
  //    2. inverse C2C ifft
  //    3. drop the complex dimension
  // The backward therefore is
  //    1. R2C rfft (add a dummy complex dimension, then dft)
  //    2. accumulate the gradient over the conjugate-symmetric entries.
  //       Since rfft results are conjugate symmetric, it is enough to double
  //       those onesided entries whose reflected index falls outside the
  //       onesided range. With N the signal length along the last dim:
  //           i.   idx = 0.
  //                Reflected to (N - 0) % N = 0. Not doubled.
  //           ii.  0 < idx < floor(N/2).
  //                Reflected to N - idx, and N > N - idx > ceil(N/2), which is
  //                outside the onesided range. Doubled.
  //           iii. idx = floor(N/2) = N/2 (last) when N even.
  //                Reflected to (N - N/2) % N = N/2. Not doubled.
  //           iv.  idx = floor(N/2) = (N-1)/2 (last) when N odd.
  //                Reflected to (N - (N-1)/2) % N = (N+1)/2. Doubled.
  //       So the entries to double are
  //           idx = 1, 2, ..., N/2 - 1     when N even
  //           idx = 1, 2, ..., (N-1)/2     when N odd
  //       i.e.
  //           idx = 1, 2, ..., N - (floor(N/2) + 1)
  //               = 1, 2, ..., N - onesided_length
  auto grad_input = at::_fft_r2c(grad, dim, normalization, /*onesided=*/true);

  const auto last_dim = dim.back();
  // N - onesided_length
  const auto num_doubled = grad.size(last_dim) - grad_input.size(last_dim);
  if (num_doubled > 0) {  // also covers case when signal size is zero
    grad_input.narrow(last_dim, 1, num_doubled).mul_(2);
  }
  return grad_input;
}

Tensor fft_r2c_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization,
                        bool onesided, int64_t last_dim_size) {
  // Twosided forward: the backward is the real part of the inverse C2C transform.
  if (!onesided) {
    return at::real(at::_fft_c2c(grad, dim, normalization, /*forward=*/false));
  }

  // Forward is a onesided R2C transform (rfft), which can be viewed as
  //     1. view the input as complex numbers (imaginary part zero)
  //     2. C2C fft
  //     3. discard half of the results
  // The backward therefore is
  //     1. fill the discarded half with zeros (shape `full_shape` below);
  //        the C2C ifft only takes twosided inputs, so the padding happens here
  //     2. inverse C2C ifft
  //     3. drop the complex dimension
  const auto onesided_sizes = grad.sizes();
  at::DimVector full_shape(onesided_sizes.begin(), onesided_sizes.end());
  const auto last_dim = at::maybe_wrap_dim(dim.back(), onesided_sizes.size());
  full_shape[last_dim] = last_dim_size;

  // Number of entries along the last dim that must be zero-filled.
  const auto pad_length = last_dim_size - grad.size(dim.back());
  if (pad_length <= 0) {
    // Nothing to pad: grad already has the full length.
    return at::real(at::_fft_c2c(grad, dim, normalization, /*forward=*/false));
  }

  auto padded_grad = at::zeros(full_shape, grad.options());
  padded_grad.slice(last_dim, 0, onesided_sizes[last_dim]).copy_(grad);
  return at::real(at::_fft_c2c(padded_grad, dim, normalization, /*forward=*/false));
}

// Helper for batchnorm_double_backward
// Sums `to_sum` over every dimension except dim 1. With keepdim=true the
// reduced dims are kept with size 1; with keepdim=false they are dropped, so
// after the first reduction the original dim 1 sits at position 0.
Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim=true) {
  Tensor reduced = to_sum.sum(0, keepdim);
  // Lowest position that must survive: 1 if dim 0 was kept, 0 if it was dropped.
  const int64_t preserved_dim = keepdim ? 1 : 0;
  // Reduce from the back so the positions still to be visited do not shift.
  int64_t current = reduced.dim() - 1;
  while (current > preserved_dim) {
    reduced = reduced.sum(current, keepdim);
    --current;
  }
  return reduced;
}

// Helper for batchnorm_double_backward
// Like expand_as_dim1 below, but without the final expand_as: size-1 dims are
// inserted at position 1 until `src` has one dim fewer than `target`, then a
// leading size-1 dim is prepended, i.e. the result looks as if the reductions
// had been done with keepdim=True.
Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) {
  // Kept in the unsigned type of sizes().size(), exactly as the comparison needs it.
  const auto rank_minus_one = target.sizes().size() - 1;

  Tensor reshaped = src;
  while (reshaped.sizes().size() < rank_minus_one) {
    reshaped = reshaped.unsqueeze(1);
  }
  if (reshaped.sizes().size() == rank_minus_one) {
    reshaped = reshaped.unsqueeze(0);
  }
  return reshaped;
}

// Helper for batchnorm_double_backward
// gamma/ggG/ggB are 1-dimensional and correspond to dim==1 of the input, so a
// plain expand_as would line them up with the wrong (trailing) dimension under
// the broadcasting rules. Size-1 dims are appended at position 1 first, until
// only the leading dim is missing, and then the result is expanded.
Tensor expand_as_dim1(const Tensor& src, const Tensor& target) {
  // Kept in the unsigned type of sizes().size(), exactly as the comparison needs it.
  const auto rank_minus_one = target.sizes().size() - 1;

  Tensor reshaped = src;
  while (reshaped.sizes().size() < rank_minus_one) {
    reshaped = reshaped.unsqueeze(1);
  }
  return reshaped.expand_as(target);
}

// Double backward for batch norm. From the incoming gradients ggI, ggG and ggB
// and the first backward's grad output gO, computes the tuple (gI, gG, ggO):
//   gI  - the term built under "calculate gI" plus the gamma contribution,
//   gG  - the term for gamma, reduced over every dim except dim 1,
//   ggO - the term for gO.
// Each of ggI / ggG / ggB may be undefined; the matching contributions are then
// skipped, so any of the returned tensors can be undefined.
// In training mode the statistics are save_mean / save_invstd, otherwise
// running_mean and (running_var + eps)^(-1/2).
// NOTE(review): ggI / ggG / ggB are presumably the gradients w.r.t. the first
// backward's grad_input / grad_weight / grad_bias -- confirm against the caller.
std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
    const Tensor & input,
    const c10::optional<Tensor> & gamma,
    const Tensor & ggI,
    const Tensor & ggG,
    const Tensor & ggB,
    const Tensor & gO,
    const c10::optional<Tensor> & running_mean,
    const c10::optional<Tensor> & running_var,
    bool training,
    double eps,
    const c10::optional<Tensor> & save_mean,
    const c10::optional<Tensor> & save_invstd,
    std::array<bool,3> output_mask) {

  // affine: a gamma was passed and it is a defined tensor
  bool affine = isDefined(gamma);
  // TODO: Do we have a ScalarOrTensor type?  Would such a thing exist?
  Tensor gamma_expanded;
  Tensor ggG_expanded, ggB_expanded;
  if (affine) {
    // line the 1-d per-channel tensors up with dim 1 of input
    gamma_expanded = expand_as_dim1(*gamma, input);
    if (ggG.defined()) {
      ggG_expanded = expand_as_dim1(ggG, input);
    }
    if (ggB.defined()) {
      ggB_expanded = expand_as_dim1(ggB, input);
    }
  } else {
    // no gamma: a 0-dim one lets the formulas below multiply unconditionally
    gamma_expanded = at::ones({}, input.options());
  }

  // define some terms we will reuse
  // M = size of dim 0 times the sizes of all dims from 2 on, i.e. the number
  // of elements per index of dim 1
  auto M = input.size(0);
  for (auto s : input.sizes().slice(2)) {
    M *= s;
  }
  // for half inputs, save_mean, save_invstd are float (ideally, we would cast
  // everything else, but not now)
  auto mu = unsqueeze_dim1(training ? toLegacyTensor(save_mean).to(input.scalar_type()) : toLegacyTensor(running_mean), input);
  auto input_sub_mu = input - mu;
  // (sigma^2 + eps)^(-1/2), shaped to broadcast against input
  auto sigma2_eps_neg_1_2 = unsqueeze_dim1(
      training ? toLegacyTensor(save_invstd).to(input.scalar_type())
               : toLegacyTensor(running_var).add(Scalar(eps)).pow(-0.5),
      input);
  // (sigma^2 + eps)^(-1) and (sigma^2 + eps)^(-3/2)
  auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
  auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);

  // calculate gI
  auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;
  auto gOinmu_sum = sum_exclude_dim1(gO * input_sub_mu);
  auto gO_sum = sum_exclude_dim1(gO);

  Tensor gI;
  // the ggI contribution to gI is only computed in training mode
  if (ggI.defined() && training) {
    auto ggI_sum = sum_exclude_dim1(ggI);
    auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu);
    // the in-place ops below only touch freshly created temporaries
    auto all_sub = ((ggI_sum * gO_sum).div_(M)).sub_(sum_exclude_dim1(gO * ggI)).add_(
                    (sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum).mul_(3. / M));
    auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M);
    auto gI_1t = (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO);
    auto gI_2t = (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI);
    gI = gamma_expanded * (gI_0t.add_(gI_1t).add_(gI_2t));
  }

  // add contribution of gamma term to gI
  Tensor gI_G_term;
  if (affine && ggG.defined()) {
    if (training) {
      auto t0 = gO * sigma2_eps_neg_1_2;
      auto t1 = (sigma2_eps_neg_1_2 * gO_sum).div_(-M);
      auto t2 = (input_mu_sigma2_neg_3_2 * sum_exclude_dim1(gO * input_sub_mu)).div_(-M);
      gI_G_term = ggG_expanded * (t0.add_(t1).add_(t2));
      gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
    } else {
      // with fixed (running) statistics only the direct scaling term remains
      gI_G_term = ggG_expanded * sigma2_eps_neg_1_2 * gO;
      gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
    }
  }

  // this is the first backward's grad_input
  // (the lambda's gO / gamma parameters shadow the outer variables on purpose:
  // it is invoked below with different arguments)
  auto first_back_grad_input = [&](const Tensor& gO, const Tensor& gamma) -> Tensor {
    auto h0 = (gamma * sigma2_eps_neg_1_2).div_(M);
    auto h1 = (M * gO).sub_(sum_exclude_dim1(gO)).sub_(
                input_sub_mu.mul(sigma2_eps_neg_1) * sum_exclude_dim1(gO * input_sub_mu));
    return h0 * h1;
  };

  // calculate gG
  Tensor gG;
  if (affine && ggI.defined()) {
    if (training) {
      // gG is just the first backwards with the gamma term removed (then shaped properly)
      gG = ggI * first_back_grad_input(gO, at::ones({}, sigma2_eps_neg_1_2.options()));
      gG = sum_exclude_dim1(gG, false);
    } else {
      gG = sum_exclude_dim1(ggI * gO * sigma2_eps_neg_1_2, false);
    }
  }

  // calculate ggO
  Tensor ggO;
  // contribution of input term
  if (ggI.defined()) {
    if (training) {
      ggO = first_back_grad_input(ggI, gamma_expanded);
    } else {
      ggO = ggI * sigma2_eps_neg_1_2 * gamma_expanded;
    }
  }
  // contribution of gamma term
  if (ggG.defined()) {
    auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
    ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
  }
  // contribution of bias term
  if (ggB.defined()) {
    auto ggO_B_term = ggB_expanded;
    ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
  }

  // sanity check only: a requested gamma gradient without a result is
  // acceptable as long as gamma itself exists
  if (output_mask[1] && !gG.defined()) {
    AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
  }

  return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};

}

// Backward of native_layer_norm written purely with differentiable tensor ops.
// dY, dmean and drstd are the incoming gradients for the three forward outputs
// (each may be undefined); returns (dX, dgamma, dbeta), where an entry stays
// undefined when it is not requested via grad_input_mask or cannot be formed
// from the defined incoming gradients.
// NOTE(review): the name suggests this variant exists so that the backward
// itself can be differentiated again -- confirm against the caller.
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_layer_norm_backward(
    const Tensor& dY,
    const Tensor& dmean,
    const Tensor& drstd,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const c10::optional<Tensor>& gamma,
    IntArrayRef normalized_shape,
    double eps,
    std::array<bool, 3> grad_input_mask) {

  // The trailing normalized_ndim dims of X are the normalized ones; the input
  // is viewed as an (M, N) matrix with M = product of the leading dims and
  // N = product of the normalized dims.
  const int normalized_ndim = normalized_shape.size();
  const auto input_shape = X.sizes();
  const auto input_ndim = X.dim();
  const int axis = input_ndim - normalized_ndim;
  const int64_t M =
      at::prod_intlist(input_shape.cbegin(), input_shape.cbegin() + axis);
  const int64_t N =
      at::prod_intlist(input_shape.cbegin() + axis, input_shape.cend());

  Tensor dX;
  Tensor dgamma;
  Tensor dbeta;

  const Tensor X_tensor = X.reshape({M, N});
  const Tensor mean_tensor = mean.reshape({M, 1});
  const Tensor rstd_tensor = rstd.reshape({M, 1});
  // 1 / N, the averaging factor of the per-row statistics
  const double s = 1.0 / static_cast<double>(N);

  Tensor dY_tensor;
  if (dY.defined()) {
    dY_tensor = dY.reshape({M, N});
  }

  if (grad_input_mask[0]) {
    Tensor gamma_tensor;
    if (isDefined(gamma)) {
      gamma_tensor = gamma->reshape({1, N});
    }
    Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
    Tensor var;
    Tensor dvar;
    if (drstd.defined()) {
      // recover var from rstd = (var + eps)^(-1/2), clamped to be non-negative
      var = ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
      // chain rule: d rstd / d var = -0.5 * rstd^3
      dvar = -0.5 * rstd_cube * drstd.view({M, 1});
    }
    Tensor ds;
    Tensor db;
    if (dY.defined()) {
      // per-row sums of dY * X (* gamma) and of dY (* gamma), kept as (M, 1)
      ds = (isDefined(gamma) ? dY_tensor * X_tensor * gamma_tensor
                            : dY_tensor * X_tensor)
               .sum(1)
               .unsqueeze_(-1);
      db = (isDefined(gamma) ? dY_tensor * gamma_tensor : dY_tensor)
               .sum(1)
               .unsqueeze_(-1);
      // dX is assembled as a * dY (* gamma) + b * X + c
      const Tensor& a = rstd_tensor;
      const Tensor b = (db * mean_tensor - ds) * rstd_cube * s;
      const Tensor c = -b * mean_tensor - db * rstd_tensor * s;
      if (isDefined(gamma)) {
        dX = a * dY_tensor * gamma_tensor + b * X_tensor + c;
      } else {
        dX = a * dY_tensor + b * X_tensor + c;
      }
      // gradients flowing in through the mean / rstd outputs are only added
      // when both of them are defined
      if (dmean.defined() && drstd.defined()) {
        dX += var_std_mean_backward(
            {dvar, dmean.view({M, 1})},
            X_tensor,
            var,
            mean_tensor,
            {1},
            false,
            true,
            false);
      }
      dX = dX.reshape_as(X);
    } else if (dmean.defined() && drstd.defined()) {
      // no dY: dX consists of the mean / rstd contribution alone
      dX = var_std_mean_backward(
               {dvar, dmean.view({M, 1})},
               X_tensor,
               var,
               mean_tensor,
               {1},
               false,
               true,
               false)
               .reshape_as(X);
    }
  }

  // NOTE(review): both reshape_as calls below use toLegacyTensor(gamma), which
  // is an undefined Tensor when gamma is absent -- presumably the masks are
  // only set when gamma exists; confirm against the caller.
  if (grad_input_mask[1] && dY.defined()) {
    dgamma = (dY_tensor * (X_tensor - mean_tensor) * rstd_tensor)
                 .sum(0)
                 .reshape_as(toLegacyTensor(gamma));
  }
  if (grad_input_mask[2] && dY.defined()) {
    dbeta = dY_tensor.sum(0).reshape_as(toLegacyTensor(gamma));
  }

  return std::make_tuple(dX, dgamma, dbeta);
}

// Backward of native_group_norm written purely with differentiable tensor ops.
// dY, dmean and drstd are the incoming gradients for the three forward outputs
// (each may be undefined); returns (dX, dgamma, dbeta), where an entry stays
// undefined when it is not requested via grad_input_mask or cannot be formed
// from the defined incoming gradients.
// NOTE(review): the name suggests this variant exists so that the backward
// itself can be differentiated again -- confirm against the caller.
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_group_norm_backward(
    const Tensor& dY,
    const Tensor& dmean,
    const Tensor& drstd,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const c10::optional<Tensor>& gamma,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps,
    std::array<bool, 3> grad_input_mask) {
  // G groups of D = C / G channels each; everything is viewed as
  // (N, G, D, HxW) and the statistics as (N, G, 1, 1).
  const int64_t G = group;
  const int64_t D = C / G;
  // 1 / (number of elements per group)
  const double s = 1.0 / static_cast<double>(D * HxW);
  Tensor dX;
  Tensor dgamma;
  Tensor dbeta;
  const Tensor X_tensor = X.reshape({N, G, D, HxW});
  const Tensor mean_tensor = mean.reshape({N, G, 1, 1});
  const Tensor rstd_tensor = rstd.reshape({N, G, 1, 1});
  Tensor dY_tensor;
  Tensor ds;
  Tensor db;
  if (dY.defined()) {
    dY_tensor = dY.reshape({N, G, D, HxW});
    // sums over the last (HxW) dim of dY * X and of dY, kept as (N, G, D, 1)
    ds = (dY_tensor * X_tensor).sum(3).unsqueeze_(-1);
    db = dY_tensor.sum(3).unsqueeze_(-1);
  }
  if (grad_input_mask[0]) {
    Tensor gamma_tensor;
    if (isDefined(gamma)) {
      gamma_tensor = gamma->reshape({1, G, D, 1});
    }
    // recover var from rstd = (var + eps)^(-1/2), clamped to be non-negative
    const Tensor var =
        ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
    const Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
    Tensor dvar;
    if (drstd.defined()) {
      // chain rule: d rstd / d var = -0.5 * rstd^3
      dvar = -0.5 * rstd_cube * drstd.view({N, G, 1, 1});
    }
    if (dY.defined()) {
      // dX is assembled as a * dY + b * X + c
      const Tensor a =
          isDefined(gamma) ? rstd_tensor * gamma_tensor : rstd_tensor;
      // b and c first hold the per-group sums of ds / db (weighted by gamma
      // when present) and are then overwritten with the final coefficients;
      // note that the new c is computed from the already updated b
      Tensor b = (isDefined(gamma) ? (ds * gamma_tensor).sum(2) : ds.sum(2))
                     .unsqueeze_(-2);
      Tensor c = (isDefined(gamma) ? (db * gamma_tensor).sum(2) : db.sum(2))
                     .unsqueeze_(-2);
      b = (c * mean_tensor - b) * rstd_cube * s;
      c = -b * mean_tensor - c * rstd_tensor * s;
      dX = a * dY_tensor + b * X_tensor + c;
      // gradients flowing in through the mean / rstd outputs are only added
      // when both of them are defined
      if (dmean.defined() && drstd.defined()) {
        dX += var_std_mean_backward(
            {dvar, dmean.view({N, G, 1, 1})},
            X_tensor,
            var,
            mean_tensor,
            {2, 3},
            false,
            true,
            false);
      }
      dX = dX.reshape_as(X);
    } else if (dmean.defined() && drstd.defined()) {
      // no dY: dX consists of the mean / rstd contribution alone
      dX = var_std_mean_backward(
               {dvar, dmean.view({N, G, 1, 1})},
               X_tensor,
               var,
               mean_tensor,
               {2, 3},
               false,
               true,
               false)
               .reshape_as(X);
    }
  }
  // NOTE(review): both reshape_as calls below use toLegacyTensor(gamma), which
  // is an undefined Tensor when gamma is absent -- presumably the masks are
  // only set when gamma exists; confirm against the caller.
  if (grad_input_mask[1] && dY.defined()) {
    dgamma = ((ds - db * mean_tensor) * rstd_tensor).sum(0).reshape_as(toLegacyTensor(gamma));
  }
  if (grad_input_mask[2] && dY.defined()) {
    dbeta = db.sum(0).reshape_as(toLegacyTensor(gamma));
  }

  return std::make_tuple(dX, dgamma, dbeta);
}

// Backward of _trilinear. The gradient for each input is again a _trilinear
// call in which grad_out takes that input's slot and the roles of `sumdim` and
// that input's expand list are exchanged. Only the gradients requested via
// grad_mask are computed; all three stay undefined when grad_out is undefined.
std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(const Tensor& grad_out, const Tensor& i1, const Tensor& i2, const Tensor& i3,
                                                       IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3,
                                                       IntArrayRef sumdim, int64_t unroll_dim, std::array<bool, 3> grad_mask) {
  Tensor grad_i1, grad_i2, grad_i3;
  if (!grad_out.defined()) {
    return std::tuple<Tensor, Tensor, Tensor>(grad_i1, grad_i2, grad_i3);
  }

  const bool want_i1 = grad_mask[0];
  const bool want_i2 = grad_mask[1];
  const bool want_i3 = grad_mask[2];

  if (want_i1) {
    grad_i1 = at::_trilinear(grad_out, i2, i3, sumdim, expand2, expand3, expand1);
  }
  if (want_i2) {
    grad_i2 = at::_trilinear(i1, grad_out, i3, expand1, sumdim, expand3, expand2);
  }
  if (want_i3) {
    grad_i3 = at::_trilinear(i1, i2, grad_out, expand1, expand2, sumdim, expand3);
  }
  return std::tuple<Tensor, Tensor, Tensor>(grad_i1, grad_i2, grad_i3);
}

// Backward of log1p: grad / conj(self + 1). Sparse inputs are rejected with an
// error instead of being differentiated.
Tensor log1p_backward(const Tensor& grad, const Tensor& self) {
  const bool input_is_sparse = self.is_sparse();
  if (input_is_sparse) {
    AT_ERROR(
      "log1p of a sparse tensor is made to be non-differentiable since ",
      "local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. ",
      "Use a different mathematical operation which preserves sparsity of gradients, ",
      "or report a bug if you think this is an error.");
  }
  // d/dx log(1 + x) = 1 / (1 + x), conjugated
  const Tensor denominator = (self + 1).conj();
  return grad / denominator;
}

// Gradient for the `values` argument of a sparse tensor constructor: the
// incoming gradient is densified (if it is sparse), its leading
// indices.size(0) dims are flattened into one, and the rows addressed by the
// flattened `indices` are gathered.
Tensor sparse_constructor_values_backward(const Tensor& sparse_grad_out, const Tensor& indices, IntArrayRef values_shape) {
  // TODO: improve this backward by writing a kernel (maybe)
  Tensor grad_dense = sparse_grad_out;
  if (sparse_grad_out.is_sparse()) {
    grad_dense = sparse_grad_out.to_dense();
  }

  const auto grad_sizes = sparse_grad_out.sizes();

  // Same shape as the values, except that dim 0 spans the product of the
  // leading indices.size(0) dims of the gradient.
  auto collapsed_shape = values_shape.vec();
  const auto leading_sizes = grad_sizes.slice(0, indices.size(0));
  collapsed_shape[0] = at::prod_intlist(leading_sizes);

  const auto grad_collapsed = grad_dense.view(collapsed_shape);
  const auto linear_indices = at::sparse::flatten_indices(indices, grad_sizes);
  return grad_collapsed.index_select(0, linear_indices);
}

// The backward of pad(input, pads) is pad(grad_output, [-p for p in pads]):
// padding with the negated amounts crops exactly what the forward added.
Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
  auto negated_pad = pad.vec();
  for (auto& amount : negated_pad) {
    amount = -amount;
  }
  return at::constant_pad_nd(grad, negated_pad, 0);
}

// Double backward of the dense embedding: gathers the rows of `grad` selected
// by `indices`, zeroes the rows that belong to padding_idx (when it is
// non-negative) and returns the result shaped as indices.sizes() + [-1].
Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) {
  // The first backward already takes care of scaling by frequency, so there
  // is nothing to do about it here.
  const auto flat_indices = indices.reshape(-1);
  auto gathered = grad.index_select(0, flat_indices);

  // Output shape: the shape of indices followed by one inferred dimension.
  auto out_shape = indices.sizes().vec();
  out_shape.push_back(-1);

  if (padding_idx >= 0) {
    const auto is_padding = (indices == padding_idx).reshape({-1, 1});
    gathered.masked_fill_(is_padding, 0);
  }
  return gathered.view(out_shape);
}

// Backward of index: writes `grad` into the caller-provided `zeros_like_self`
// at `indices` via _index_put_impl_ and returns the result.
Tensor index_backward(Tensor zeros_like_self, TensorList indices, const Tensor& grad) {
  // Flag names follow the _index_put_impl_ op schema.
  return at::_index_put_impl_(
      zeros_like_self, indices, grad, /*accumulate=*/true, /*unsafe=*/true);
}

// Backward of the cuDNN CTC loss: scales the raw gradient produced by the
// forward pass with the incoming gradient of the loss.
// loss and grad_out are unsqueezed at dims 0 and 2 so that they broadcast
// against raw_grad. With zero_infinity set, positions whose loss compares equal
// to 0 receive a zero gradient instead.
// NOTE(review): presumably raw_grad is (T, N, C) and loss / grad_out are (N,),
// and the forward has already replaced infinite losses by 0 -- confirm against
// the forward implementation.
Tensor _cudnn_ctc_loss_backward(const Tensor& grad_out, const Tensor& loss, const Tensor& raw_grad, bool zero_infinity) {
  if (zero_infinity) {
    return at::where(
        loss.unsqueeze(0).unsqueeze(2) == 0,
        // Must be a 0-dim (scalar) tensor so that it broadcasts against the
        // other operands; a shape of {0} would be an empty 1-d tensor that
        // cannot broadcast to raw_grad's shape.
        at::zeros({}, raw_grad.options()),
        raw_grad * grad_out.unsqueeze(0).unsqueeze(2));
  } else {
    return raw_grad * grad_out.unsqueeze(0).unsqueeze(2);
  }
}

// Returns true if at least one entry of `variables` is a defined tensor.
bool any_variable_defined(variable_list& variables) {
  // Iterate by const reference: copying each entry only to query defined()
  // would cost a reference-count increment/decrement per element.
  for (const auto& variable : variables) {
    if (variable.defined()) {
      return true;
    }
  }
  return false;
}

} // namespace details
} // namespace generated
} // namespace autograd
} // namespace torch
